From 207c27b7f1502f3250d69caf7a71481e81de6fb4 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Sat, 11 Apr 2020 18:25:10 +0100 Subject: [PATCH 01/14] Python 3.6+ syntax --- .gitignore | 1 + distributed/client.py | 5 +- distributed/tests/test_publish.py | 116 ++++++++++++++---------------- 3 files changed, 57 insertions(+), 65 deletions(-) diff --git a/.gitignore b/.gitignore index 86ee425adf..cf6732eaa7 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ dask-worker-space/ *.swp .ycm_extra_conf.py tags +.ipynb_checkpoints diff --git a/distributed/client.py b/distributed/client.py index 6545e93851..a87a080324 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2248,11 +2248,10 @@ def list_datasets(self, **kwargs): async def _get_dataset(self, name): out = await self.scheduler.publish_get(name=name, client=self.id) if out is None: - raise KeyError("Dataset '%s' not found" % name) + raise KeyError(f"Dataset '{name}' not found") with temp_default_client(self): - data = out["data"] - return data + return out["data"] def get_dataset(self, name, **kwargs): """ diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index dde10b11cf..ee4b8e4e03 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -11,97 +11,90 @@ @gen_cluster(client=False) -def test_publish_simple(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_publish_simple(s, a, b): + c = await Client(s.address, asynchronous=True, set_as_default=False) + f = await Client(s.address, asynchronous=True, set_as_default=False) - data = yield c.scatter(range(3)) - out = yield c.publish_dataset(data=data) + data = await c.scatter(range(3)) + await c.publish_dataset(data=data) assert "data" in s.extensions["publish"].datasets assert isinstance(s.extensions["publish"].datasets["data"]["data"], Serialized) with pytest.raises(KeyError) as exc_info: - out = yield c.publish_dataset(data=data) + await c.publish_dataset(data=data) assert "exists" in str(exc_info.value) assert "data" in str(exc_info.value) - result = yield c.scheduler.publish_list() + result = await c.scheduler.publish_list() assert result == ("data",) - result = yield f.scheduler.publish_list() + result = await f.scheduler.publish_list() assert result == ("data",) - yield c.close() - yield f.close() + await c.close() + await f.close() @gen_cluster(client=False) -def test_publish_non_string_key(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) - - try: +async def test_publish_non_string_key(s, a, b): + async with Client(s.address, asynchronous=True, set_as_default=False) as c: for name in [("a", "b"), 9.0, 8]: - data = yield c.scatter(range(3)) - out = yield c.publish_dataset(data, name=name) + data = await c.scatter(range(3)) + await c.publish_dataset(data, name=name) assert name in s.extensions["publish"].datasets assert isinstance( s.extensions["publish"].datasets[name]["data"], Serialized ) - datasets = yield c.scheduler.publish_list() + datasets = await c.scheduler.publish_list() assert name in datasets - finally: - yield c.close() - yield f.close() - @gen_cluster(client=False) -def test_publish_roundtrip(s, a, b): - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) +async def test_publish_roundtrip(s, a, b): + c = await Client(s.address, asynchronous=True, set_as_default=False) + f = await Client(s.address, 
asynchronous=True, set_as_default=False) - data = yield c.scatter([0, 1, 2]) - yield c.publish_dataset(data=data) + data = await c.scatter([0, 1, 2]) + await c.publish_dataset(data=data) assert "published-data" in s.who_wants[data[0].key] - result = yield f.get_dataset(name="data") + result = await f.get_dataset(name="data") assert len(result) == len(data) - out = yield f.gather(result) + out = await f.gather(result) assert out == [0, 1, 2] with pytest.raises(KeyError) as exc_info: - result = yield f.get_dataset(name="nonexistent") + await f.get_dataset(name="nonexistent") assert "not found" in str(exc_info.value) assert "nonexistent" in str(exc_info.value) - yield c.close() - yield f.close() + await c.close() + await f.close() @gen_cluster(client=True) -def test_unpublish(c, s, a, b): - data = yield c.scatter([0, 1, 2]) - yield c.publish_dataset(data=data) +async def test_unpublish(c, s, a, b): + data = await c.scatter([0, 1, 2]) + await c.publish_dataset(data=data) key = data[0].key del data - yield c.scheduler.publish_delete(name="data") + await c.scheduler.publish_delete(name="data") assert "data" not in s.extensions["publish"].datasets start = time() while key in s.who_wants: - yield gen.sleep(0.01) + await gen.sleep(0.01) assert time() < start + 5 with pytest.raises(KeyError) as exc_info: - result = yield c.get_dataset(name="data") + await c.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @@ -113,19 +106,19 @@ def test_unpublish_sync(client): client.unpublish_dataset(name="data") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name="data") + client.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @gen_cluster(client=True) -def test_publish_multiple_datasets(c, s, a, b): +async def test_publish_multiple_datasets(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(2) - yield c.publish_dataset(x=x, y=y) - datasets = yield c.scheduler.publish_list() + await c.publish_dataset(x=x, y=y) + datasets = await c.scheduler.publish_list() assert set(datasets) == {"x", "y"} @@ -136,7 +129,7 @@ def test_unpublish_multiple_datasets_sync(client): client.unpublish_dataset(name="x") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name="x") + client.get_dataset(name="x") datasets = client.list_datasets() assert set(datasets) == {"y"} @@ -147,17 +140,17 @@ def test_unpublish_multiple_datasets_sync(client): client.unpublish_dataset(name="y") with pytest.raises(KeyError) as exc_info: - result = client.get_dataset(name="y") + client.get_dataset(name="y") assert "not found" in str(exc_info.value) assert "y" in str(exc_info.value) @gen_cluster(client=False) -def test_publish_bag(s, a, b): +async def test_publish_bag(s, a, b): db = pytest.importorskip("dask.bag") - c = yield Client(s.address, asynchronous=True) - f = yield Client(s.address, asynchronous=True) + c = await Client(s.address, asynchronous=True, set_as_default=False) + f = await Client(s.address, asynchronous=True, set_as_default=False) bag = db.from_sequence([0, 1, 2]) bagp = c.persist(bag) @@ -166,19 +159,19 @@ def test_publish_bag(s, a, b): keys = {f.key for f in futures_of(bagp)} assert keys == set(bag.dask) - yield c.publish_dataset(data=bagp) + await c.publish_dataset(data=bagp) # check that serialization didn't affect original bag's dask assert len(futures_of(bagp)) == 3 - result = yield f.get_dataset("data") + result = await f.get_dataset("data") assert set(result.dask.keys()) == 
set(bagp.dask.keys()) assert {f.key for f in result.dask.values()} == {f.key for f in bagp.dask.values()} - out = yield f.compute(result) + out = await f.compute(result) assert out == [0, 1, 2] - yield c.close() - yield f.close() + await c.close() + await f.close() def test_datasets_setitem(client): @@ -223,19 +216,18 @@ def test_datasets_iter(client): @gen_cluster(client=True) -def test_pickle_safe(c, s, a, b): - c2 = yield Client(s.address, asynchronous=True, serializers=["msgpack"]) - try: - yield c2.publish_dataset(x=[1, 2, 3]) - result = yield c2.get_dataset("x") +async def test_pickle_safe(c, s, a, b): + async with Client( + s.address, asynchronous=True, serializers=["msgpack"], set_as_default=False + ) as c2: + await c2.publish_dataset(x=[1, 2, 3]) + result = await c2.get_dataset("x") assert result == [1, 2, 3] with pytest.raises(TypeError): - yield c2.publish_dataset(y=lambda x: x) + await c2.publish_dataset(y=lambda x: x) - yield c.publish_dataset(z=lambda x: x) # this can use pickle + await c.publish_dataset(z=lambda x: x) # this can use pickle with pytest.raises(TypeError): - yield c2.get_dataset("z") - finally: - yield c2.close() + await c2.get_dataset("z") From 02b1079833e0e42e7659f472ed2cc3039e333039 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 14 Apr 2020 14:05:27 +0100 Subject: [PATCH 02/14] Code polish --- distributed/client.py | 4 +-- distributed/core.py | 62 +++++++++++++++++++++--------------------- distributed/publish.py | 8 +++--- 3 files changed, 36 insertions(+), 38 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index a87a080324..53f9b7030c 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2249,9 +2249,7 @@ async def _get_dataset(self, name): out = await self.scheduler.publish_get(name=name, client=self.id) if out is None: raise KeyError(f"Dataset '{name}' not found") - - with temp_default_client(self): - return out["data"] + return out["data"] def get_dataset(self, name, **kwargs): """ diff --git a/distributed/core.py b/distributed/core.py index dd5e18d000..70e0e8235c 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -673,24 +673,24 @@ async def _close_comm(comm): return tasks def __getattr__(self, key): - async def send_recv_from_rpc(**kwargs): - if self.serializers is not None and kwargs.get("serializers") is None: - kwargs["serializers"] = self.serializers - if self.deserializers is not None and kwargs.get("deserializers") is None: - kwargs["deserializers"] = self.deserializers - try: - comm = await self.live_comm() - comm.name = "rpc." + key - result = await send_recv(comm=comm, op=key, **kwargs) - except (RPCClosed, CommClosedError) as e: - raise e.__class__( - "%s: while trying to call remote method %r" % (e, key) - ) + return partial(self._send_recv_from_rpc, key) - self.comms[comm] = True # mark as open - return result + async def _send_recv_from_rpc(self, key, **kwargs): + if self.serializers is not None and kwargs.get("serializers") is None: + kwargs["serializers"] = self.serializers + if self.deserializers is not None and kwargs.get("deserializers") is None: + kwargs["deserializers"] = self.deserializers + try: + comm = await self.live_comm() + comm.name = "rpc." 
+ key + result = await send_recv(comm=comm, op=key, **kwargs) + except (RPCClosed, CommClosedError) as e: + raise e.__class__( + "%s: while trying to call remote method %r" % (e, key) + ) - return send_recv_from_rpc + self.comms[comm] = True # mark as open + return result def close_rpc(self): if self.status != "closed": @@ -744,22 +744,22 @@ def address(self): return self.addr def __getattr__(self, key): - async def send_recv_from_rpc(**kwargs): - if self.serializers is not None and kwargs.get("serializers") is None: - kwargs["serializers"] = self.serializers - if self.deserializers is not None and kwargs.get("deserializers") is None: - kwargs["deserializers"] = self.deserializers - comm = await self.pool.connect(self.addr) - name, comm.name = comm.name, "ConnectionPool." + key - try: - result = await send_recv(comm=comm, op=key, **kwargs) - finally: - self.pool.reuse(self.addr, comm) - comm.name = name - - return result + return partial(self._send_recv_from_rpc, key) + + async def _send_recv_from_rpc(self, key, **kwargs): + if self.serializers is not None and kwargs.get("serializers") is None: + kwargs["serializers"] = self.serializers + if self.deserializers is not None and kwargs.get("deserializers") is None: + kwargs["deserializers"] = self.deserializers + comm = await self.pool.connect(self.addr) + name, comm.name = comm.name, "ConnectionPool." + key + try: + result = await send_recv(comm=comm, op=key, **kwargs) + finally: + self.pool.reuse(self.addr, comm) + comm.name = name - return send_recv_from_rpc + return result async def close_rpc(self): pass diff --git a/distributed/publish.py b/distributed/publish.py index 758e5ccc34..236c3a4707 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -6,10 +6,10 @@ class PublishExtension: """ An extension for the scheduler to manage collections - * publish-list - * publish-put - * publish-get - * publish-delete + * publish_list + * publish_put + * publish_get + * publish_delete """ def __init__(self, scheduler): From e07b9f81e34c8dd48b127d8cb1aa6af1d3446eac Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 14 Apr 2020 14:14:27 +0100 Subject: [PATCH 03/14] Revert --- distributed/core.py | 62 ++++++++++++++++++++++----------------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index 70e0e8235c..dd5e18d000 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -673,24 +673,24 @@ async def _close_comm(comm): return tasks def __getattr__(self, key): - return partial(self._send_recv_from_rpc, key) + async def send_recv_from_rpc(**kwargs): + if self.serializers is not None and kwargs.get("serializers") is None: + kwargs["serializers"] = self.serializers + if self.deserializers is not None and kwargs.get("deserializers") is None: + kwargs["deserializers"] = self.deserializers + try: + comm = await self.live_comm() + comm.name = "rpc." + key + result = await send_recv(comm=comm, op=key, **kwargs) + except (RPCClosed, CommClosedError) as e: + raise e.__class__( + "%s: while trying to call remote method %r" % (e, key) + ) - async def _send_recv_from_rpc(self, key, **kwargs): - if self.serializers is not None and kwargs.get("serializers") is None: - kwargs["serializers"] = self.serializers - if self.deserializers is not None and kwargs.get("deserializers") is None: - kwargs["deserializers"] = self.deserializers - try: - comm = await self.live_comm() - comm.name = "rpc." 
+ key - result = await send_recv(comm=comm, op=key, **kwargs) - except (RPCClosed, CommClosedError) as e: - raise e.__class__( - "%s: while trying to call remote method %r" % (e, key) - ) + self.comms[comm] = True # mark as open + return result - self.comms[comm] = True # mark as open - return result + return send_recv_from_rpc def close_rpc(self): if self.status != "closed": @@ -744,22 +744,22 @@ def address(self): return self.addr def __getattr__(self, key): - return partial(self._send_recv_from_rpc, key) - - async def _send_recv_from_rpc(self, key, **kwargs): - if self.serializers is not None and kwargs.get("serializers") is None: - kwargs["serializers"] = self.serializers - if self.deserializers is not None and kwargs.get("deserializers") is None: - kwargs["deserializers"] = self.deserializers - comm = await self.pool.connect(self.addr) - name, comm.name = comm.name, "ConnectionPool." + key - try: - result = await send_recv(comm=comm, op=key, **kwargs) - finally: - self.pool.reuse(self.addr, comm) - comm.name = name + async def send_recv_from_rpc(**kwargs): + if self.serializers is not None and kwargs.get("serializers") is None: + kwargs["serializers"] = self.serializers + if self.deserializers is not None and kwargs.get("deserializers") is None: + kwargs["deserializers"] = self.deserializers + comm = await self.pool.connect(self.addr) + name, comm.name = comm.name, "ConnectionPool." + key + try: + result = await send_recv(comm=comm, op=key, **kwargs) + finally: + self.pool.reuse(self.addr, comm) + comm.name = name + + return result - return result + return send_recv_from_rpc async def close_rpc(self): pass From 5d2566c7a306e7835e9a8c474f667fcd8ea07a7b Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 14 Apr 2020 14:15:02 +0100 Subject: [PATCH 04/14] Polish --- distributed/tests/test_client.py | 24 ++++++--------- distributed/tests/test_scheduler.py | 1 - distributed/tests/test_security.py | 46 +++++++++++++---------------- distributed/tests/test_variable.py | 1 - distributed/tests/test_worker.py | 13 +++----- 5 files changed, 33 insertions(+), 52 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 297394631a..3cfad30370 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -152,8 +152,7 @@ def test_map(c, s, a, b): L4 = c.map(add, range(3), range(4)) results = yield c.gather(L4) - if sys.version_info[0] >= 3: - assert results == list(map(add, range(3), range(4))) + assert results == list(map(add, range(3), range(4))) def f(x, y=10): return x + y @@ -1439,9 +1438,7 @@ def test_many_submits_spread_evenly(c, s, a, b): def test_traceback(c, s, a, b): x = c.submit(div, 1, 0) tb = yield x.traceback() - - if sys.version_info[0] >= 3: - assert any("x / y" in line for line in pluck(3, traceback.extract_tb(tb))) + assert any("x / y" in line for line in pluck(3, traceback.extract_tb(tb))) @gen_cluster(client=True) @@ -1468,12 +1465,11 @@ def test_gather_traceback(c, s, a, b): def test_traceback_sync(c): x = c.submit(div, 1, 0) tb = x.traceback() - if sys.version_info[0] >= 3: - assert any( - "x / y" in line - for line in concat(traceback.extract_tb(tb)) - if isinstance(line, str) - ) + assert any( + "x / y" in line + for line in concat(traceback.extract_tb(tb)) + if isinstance(line, str) + ) y = c.submit(inc, x) tb2 = y.traceback() @@ -2597,9 +2593,8 @@ def test_run_coroutine(c, s, a, b): with pytest.raises(RuntimeError, match="hello"): yield c.run(throws, 1) - if sys.version_info >= (3, 5): - results = 
yield c.run(asyncinc, 2, delay=0.01) - assert results == {a.address: 3, b.address: 3} + results = yield c.run(asyncinc, 2, delay=0.01) + assert results == {a.address: 3, b.address: 3} def test_run_coroutine_sync(c, s, a, b): @@ -5212,7 +5207,6 @@ def test_scatter_direct(s, a, b): yield c.close() -@pytest.mark.skipif(sys.version_info[0] < 3, reason="cloudpickle Py27 issue") @gen_cluster(client=True) def test_unhashable_function(c, s, a, b): d = {"a": 1} diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 2b48fa030e..7551ae7f09 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -794,7 +794,6 @@ def test_retire_workers_no_suspicious_tasks(c, s, a, b): @pytest.mark.skipif( sys.platform.startswith("win"), reason="file descriptors not really a thing" ) -@pytest.mark.skipif(sys.version_info < (3, 6), reason="intermittent failure") @gen_cluster(client=True, nthreads=[], timeout=240) def test_file_descriptors(c, s): yield gen.sleep(0.1) diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index 8665ebead3..dae035449a 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -149,15 +149,15 @@ def test_tls_config_for_role(): sec.get_tls_config_for_role("supervisor") +def assert_many_ciphers(ctx): + assert len(ctx.get_ciphers()) > 2 # Most likely + + def test_connection_args(): def basic_checks(ctx): assert ctx.verify_mode == ssl.CERT_REQUIRED assert ctx.check_hostname is False - def many_ciphers(ctx): - if sys.version_info >= (3, 6): - assert len(ctx.get_ciphers()) > 2 # Most likely - c = { "distributed.comm.tls.ca-file": ca_file, "distributed.comm.tls.scheduler.key": key1, @@ -171,12 +171,12 @@ def many_ciphers(ctx): assert not d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) d = sec.get_connection_args("worker") ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) # No cert defined => no TLS d = sec.get_connection_args("client") @@ -193,13 +193,12 @@ def many_ciphers(ctx): assert d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - if sys.version_info >= (3, 6): - supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] - assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] - if len(tls_13_ciphers): - assert len(tls_13_ciphers) == 3 + + supported_ciphers = ctx.get_ciphers() + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] + assert len(tls_12_ciphers) == 1 + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] + assert len(tls_13_ciphers) in (0, 3) def test_listen_args(): @@ -207,10 +206,6 @@ def basic_checks(ctx): assert ctx.verify_mode == ssl.CERT_REQUIRED assert ctx.check_hostname is False - def many_ciphers(ctx): - if sys.version_info >= (3, 6): - assert len(ctx.get_ciphers()) > 2 # Most likely - c = { "distributed.comm.tls.ca-file": ca_file, "distributed.comm.tls.scheduler.key": key1, @@ -224,12 +219,12 @@ def many_ciphers(ctx): assert not d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) d = sec.get_listen_args("worker") ctx = d["ssl_context"] basic_checks(ctx) - many_ciphers(ctx) + assert_many_ciphers(ctx) # No cert defined => no TLS d = sec.get_listen_args("client") @@ -246,13 +241,12 @@ 
def many_ciphers(ctx): assert d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - if sys.version_info >= (3, 6): - supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] - assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] - if len(tls_13_ciphers): - assert len(tls_13_ciphers) == 3 + + supported_ciphers = ctx.get_ciphers() + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] + assert len(tls_12_ciphers) == 1 + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] + assert len(tls_13_ciphers) in (0, 3) @pytest.mark.asyncio diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index 64765d808c..b8aaa9275c 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -161,7 +161,6 @@ def test_timeout_get(c, s, a, b): assert result == 1 -@pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") @pytest.mark.slow @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index a5e364ec0c..42cd1bf0c9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -136,11 +136,10 @@ def reset(self): with pytest.raises(ZeroDivisionError): yield y - if sys.version_info[0] >= 3: - tb = yield y._traceback() - assert any( - "1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line - ) + tb = yield y._traceback() + assert any( + "1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line + ) assert "Compute Failed" in hdlr.messages["warning"][0] logger.setLevel(old_level) @@ -473,9 +472,6 @@ def f(dask_worker=None): @gen_cluster(client=True) def test_run_coroutine_dask_worker(c, s, a, b): - if sys.version_info < (3,) and tornado.version_info < (4, 5): - pytest.skip("test needs Tornado 4.5+ on Python 2.7") - @gen.coroutine def f(dask_worker=None): yield gen.sleep(0.001) @@ -586,7 +582,6 @@ def test_clean(c, s, a, b): assert not c -@pytest.mark.skipif(sys.version_info[:2] == (3, 4), reason="mul bytes fails") @gen_cluster(client=True) def test_message_breakup(c, s, a, b): n = 100000 From a5dc1be31c6270416df14152efd6a326963ea299 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 14 Apr 2020 14:15:02 +0100 Subject: [PATCH 05/14] Revert "Polish" This reverts commit 5d2566c7a306e7835e9a8c474f667fcd8ea07a7b. 
--- distributed/tests/test_client.py | 24 +++++++++------ distributed/tests/test_scheduler.py | 1 + distributed/tests/test_security.py | 46 ++++++++++++++++------------- distributed/tests/test_variable.py | 1 + distributed/tests/test_worker.py | 13 +++++--- 5 files changed, 52 insertions(+), 33 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 3cfad30370..297394631a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -152,7 +152,8 @@ def test_map(c, s, a, b): L4 = c.map(add, range(3), range(4)) results = yield c.gather(L4) - assert results == list(map(add, range(3), range(4))) + if sys.version_info[0] >= 3: + assert results == list(map(add, range(3), range(4))) def f(x, y=10): return x + y @@ -1438,7 +1439,9 @@ def test_many_submits_spread_evenly(c, s, a, b): def test_traceback(c, s, a, b): x = c.submit(div, 1, 0) tb = yield x.traceback() - assert any("x / y" in line for line in pluck(3, traceback.extract_tb(tb))) + + if sys.version_info[0] >= 3: + assert any("x / y" in line for line in pluck(3, traceback.extract_tb(tb))) @gen_cluster(client=True) @@ -1465,11 +1468,12 @@ def test_gather_traceback(c, s, a, b): def test_traceback_sync(c): x = c.submit(div, 1, 0) tb = x.traceback() - assert any( - "x / y" in line - for line in concat(traceback.extract_tb(tb)) - if isinstance(line, str) - ) + if sys.version_info[0] >= 3: + assert any( + "x / y" in line + for line in concat(traceback.extract_tb(tb)) + if isinstance(line, str) + ) y = c.submit(inc, x) tb2 = y.traceback() @@ -2593,8 +2597,9 @@ def test_run_coroutine(c, s, a, b): with pytest.raises(RuntimeError, match="hello"): yield c.run(throws, 1) - results = yield c.run(asyncinc, 2, delay=0.01) - assert results == {a.address: 3, b.address: 3} + if sys.version_info >= (3, 5): + results = yield c.run(asyncinc, 2, delay=0.01) + assert results == {a.address: 3, b.address: 3} def test_run_coroutine_sync(c, s, a, b): @@ -5207,6 +5212,7 @@ def test_scatter_direct(s, a, b): yield c.close() +@pytest.mark.skipif(sys.version_info[0] < 3, reason="cloudpickle Py27 issue") @gen_cluster(client=True) def test_unhashable_function(c, s, a, b): d = {"a": 1} diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 7551ae7f09..2b48fa030e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -794,6 +794,7 @@ def test_retire_workers_no_suspicious_tasks(c, s, a, b): @pytest.mark.skipif( sys.platform.startswith("win"), reason="file descriptors not really a thing" ) +@pytest.mark.skipif(sys.version_info < (3, 6), reason="intermittent failure") @gen_cluster(client=True, nthreads=[], timeout=240) def test_file_descriptors(c, s): yield gen.sleep(0.1) diff --git a/distributed/tests/test_security.py b/distributed/tests/test_security.py index dae035449a..8665ebead3 100644 --- a/distributed/tests/test_security.py +++ b/distributed/tests/test_security.py @@ -149,15 +149,15 @@ def test_tls_config_for_role(): sec.get_tls_config_for_role("supervisor") -def assert_many_ciphers(ctx): - assert len(ctx.get_ciphers()) > 2 # Most likely - - def test_connection_args(): def basic_checks(ctx): assert ctx.verify_mode == ssl.CERT_REQUIRED assert ctx.check_hostname is False + def many_ciphers(ctx): + if sys.version_info >= (3, 6): + assert len(ctx.get_ciphers()) > 2 # Most likely + c = { "distributed.comm.tls.ca-file": ca_file, "distributed.comm.tls.scheduler.key": key1, @@ -171,12 +171,12 @@ def basic_checks(ctx): assert not 
d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - assert_many_ciphers(ctx) + many_ciphers(ctx) d = sec.get_connection_args("worker") ctx = d["ssl_context"] basic_checks(ctx) - assert_many_ciphers(ctx) + many_ciphers(ctx) # No cert defined => no TLS d = sec.get_connection_args("client") @@ -193,12 +193,13 @@ def basic_checks(ctx): assert d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - - supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] - assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] - assert len(tls_13_ciphers) in (0, 3) + if sys.version_info >= (3, 6): + supported_ciphers = ctx.get_ciphers() + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] + assert len(tls_12_ciphers) == 1 + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] + if len(tls_13_ciphers): + assert len(tls_13_ciphers) == 3 def test_listen_args(): @@ -206,6 +207,10 @@ def basic_checks(ctx): assert ctx.verify_mode == ssl.CERT_REQUIRED assert ctx.check_hostname is False + def many_ciphers(ctx): + if sys.version_info >= (3, 6): + assert len(ctx.get_ciphers()) > 2 # Most likely + c = { "distributed.comm.tls.ca-file": ca_file, "distributed.comm.tls.scheduler.key": key1, @@ -219,12 +224,12 @@ def basic_checks(ctx): assert not d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - assert_many_ciphers(ctx) + many_ciphers(ctx) d = sec.get_listen_args("worker") ctx = d["ssl_context"] basic_checks(ctx) - assert_many_ciphers(ctx) + many_ciphers(ctx) # No cert defined => no TLS d = sec.get_listen_args("client") @@ -241,12 +246,13 @@ def basic_checks(ctx): assert d["require_encryption"] ctx = d["ssl_context"] basic_checks(ctx) - - supported_ciphers = ctx.get_ciphers() - tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] - assert len(tls_12_ciphers) == 1 - tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] - assert len(tls_13_ciphers) in (0, 3) + if sys.version_info >= (3, 6): + supported_ciphers = ctx.get_ciphers() + tls_12_ciphers = [c for c in supported_ciphers if "TLSv1.2" in c["description"]] + assert len(tls_12_ciphers) == 1 + tls_13_ciphers = [c for c in supported_ciphers if "TLSv1.3" in c["description"]] + if len(tls_13_ciphers): + assert len(tls_13_ciphers) == 3 @pytest.mark.asyncio diff --git a/distributed/tests/test_variable.py b/distributed/tests/test_variable.py index b8aaa9275c..64765d808c 100644 --- a/distributed/tests/test_variable.py +++ b/distributed/tests/test_variable.py @@ -161,6 +161,7 @@ def test_timeout_get(c, s, a, b): assert result == 1 +@pytest.mark.skipif(sys.version_info[0] == 2, reason="Multi-client issues") @pytest.mark.slow @gen_cluster(client=True, nthreads=[("127.0.0.1", 2)] * 5, Worker=Nanny, timeout=None) def test_race(c, s, *workers): diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 42cd1bf0c9..a5e364ec0c 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -136,10 +136,11 @@ def reset(self): with pytest.raises(ZeroDivisionError): yield y - tb = yield y._traceback() - assert any( - "1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line - ) + if sys.version_info[0] >= 3: + tb = yield y._traceback() + assert any( + "1 / 0" in line for line in pluck(3, traceback.extract_tb(tb)) if line + ) assert "Compute Failed" in 
hdlr.messages["warning"][0] logger.setLevel(old_level) @@ -472,6 +473,9 @@ def f(dask_worker=None): @gen_cluster(client=True) def test_run_coroutine_dask_worker(c, s, a, b): + if sys.version_info < (3,) and tornado.version_info < (4, 5): + pytest.skip("test needs Tornado 4.5+ on Python 2.7") + @gen.coroutine def f(dask_worker=None): yield gen.sleep(0.001) @@ -582,6 +586,7 @@ def test_clean(c, s, a, b): assert not c +@pytest.mark.skipif(sys.version_info[:2] == (3, 4), reason="mul bytes fails") @gen_cluster(client=True) def test_message_breakup(c, s, a, b): n = 100000 From f86126ae95e7f1e4361b30509121b3f4a3718aff Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Wed, 15 Apr 2020 14:26:06 +0100 Subject: [PATCH 06/14] tests --- distributed/tests/test_client.py | 13 ++++ distributed/tests/test_publish.py | 116 ++++++++++++++++-------------- 2 files changed, 75 insertions(+), 54 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 297394631a..7fe09d1d80 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5966,3 +5966,16 @@ def test_as_completed_condition_loop(c, s, a, b): def test_client_connectionpool_semaphore_loop(s, a, b): with Client(s["address"]) as c: assert c.rpc.semaphore._loop is c.loop.asyncio_loop + + +@gen_cluster(client=False) +async def test_pickle_identity_issue3227(s, a, b): + """Ensure that, on a pickle round-trip, the default client is not reinitialised. + Specifically, make sure that hostname resolution won't cause get_client() to create + a new instance. + """ + for host in ("127.0.0.1", "localhost"): + url = host + ":" + s.address.split(":")[1] + async with Client(url, asynchronous=True) as c1: + c2 = pickle.loads(pickle.dumps(c1)) + assert c2 is c1 diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index ee4b8e4e03..dde10b11cf 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -11,90 +11,97 @@ @gen_cluster(client=False) -async def test_publish_simple(s, a, b): - c = await Client(s.address, asynchronous=True, set_as_default=False) - f = await Client(s.address, asynchronous=True, set_as_default=False) +def test_publish_simple(s, a, b): + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) - data = await c.scatter(range(3)) - await c.publish_dataset(data=data) + data = yield c.scatter(range(3)) + out = yield c.publish_dataset(data=data) assert "data" in s.extensions["publish"].datasets assert isinstance(s.extensions["publish"].datasets["data"]["data"], Serialized) with pytest.raises(KeyError) as exc_info: - await c.publish_dataset(data=data) + out = yield c.publish_dataset(data=data) assert "exists" in str(exc_info.value) assert "data" in str(exc_info.value) - result = await c.scheduler.publish_list() + result = yield c.scheduler.publish_list() assert result == ("data",) - result = await f.scheduler.publish_list() + result = yield f.scheduler.publish_list() assert result == ("data",) - await c.close() - await f.close() + yield c.close() + yield f.close() @gen_cluster(client=False) -async def test_publish_non_string_key(s, a, b): - async with Client(s.address, asynchronous=True, set_as_default=False) as c: +def test_publish_non_string_key(s, a, b): + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) + + try: for name in [("a", "b"), 9.0, 8]: - data = await c.scatter(range(3)) - await c.publish_dataset(data, name=name) + data = yield 
c.scatter(range(3)) + out = yield c.publish_dataset(data, name=name) assert name in s.extensions["publish"].datasets assert isinstance( s.extensions["publish"].datasets[name]["data"], Serialized ) - datasets = await c.scheduler.publish_list() + datasets = yield c.scheduler.publish_list() assert name in datasets + finally: + yield c.close() + yield f.close() + @gen_cluster(client=False) -async def test_publish_roundtrip(s, a, b): - c = await Client(s.address, asynchronous=True, set_as_default=False) - f = await Client(s.address, asynchronous=True, set_as_default=False) +def test_publish_roundtrip(s, a, b): + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) - data = await c.scatter([0, 1, 2]) - await c.publish_dataset(data=data) + data = yield c.scatter([0, 1, 2]) + yield c.publish_dataset(data=data) assert "published-data" in s.who_wants[data[0].key] - result = await f.get_dataset(name="data") + result = yield f.get_dataset(name="data") assert len(result) == len(data) - out = await f.gather(result) + out = yield f.gather(result) assert out == [0, 1, 2] with pytest.raises(KeyError) as exc_info: - await f.get_dataset(name="nonexistent") + result = yield f.get_dataset(name="nonexistent") assert "not found" in str(exc_info.value) assert "nonexistent" in str(exc_info.value) - await c.close() - await f.close() + yield c.close() + yield f.close() @gen_cluster(client=True) -async def test_unpublish(c, s, a, b): - data = await c.scatter([0, 1, 2]) - await c.publish_dataset(data=data) +def test_unpublish(c, s, a, b): + data = yield c.scatter([0, 1, 2]) + yield c.publish_dataset(data=data) key = data[0].key del data - await c.scheduler.publish_delete(name="data") + yield c.scheduler.publish_delete(name="data") assert "data" not in s.extensions["publish"].datasets start = time() while key in s.who_wants: - await gen.sleep(0.01) + yield gen.sleep(0.01) assert time() < start + 5 with pytest.raises(KeyError) as exc_info: - await c.get_dataset(name="data") + result = yield c.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @@ -106,19 +113,19 @@ def test_unpublish_sync(client): client.unpublish_dataset(name="data") with pytest.raises(KeyError) as exc_info: - client.get_dataset(name="data") + result = client.get_dataset(name="data") assert "not found" in str(exc_info.value) assert "data" in str(exc_info.value) @gen_cluster(client=True) -async def test_publish_multiple_datasets(c, s, a, b): +def test_publish_multiple_datasets(c, s, a, b): x = delayed(inc)(1) y = delayed(inc)(2) - await c.publish_dataset(x=x, y=y) - datasets = await c.scheduler.publish_list() + yield c.publish_dataset(x=x, y=y) + datasets = yield c.scheduler.publish_list() assert set(datasets) == {"x", "y"} @@ -129,7 +136,7 @@ def test_unpublish_multiple_datasets_sync(client): client.unpublish_dataset(name="x") with pytest.raises(KeyError) as exc_info: - client.get_dataset(name="x") + result = client.get_dataset(name="x") datasets = client.list_datasets() assert set(datasets) == {"y"} @@ -140,17 +147,17 @@ def test_unpublish_multiple_datasets_sync(client): client.unpublish_dataset(name="y") with pytest.raises(KeyError) as exc_info: - client.get_dataset(name="y") + result = client.get_dataset(name="y") assert "not found" in str(exc_info.value) assert "y" in str(exc_info.value) @gen_cluster(client=False) -async def test_publish_bag(s, a, b): +def test_publish_bag(s, a, b): db = pytest.importorskip("dask.bag") - c = await Client(s.address, 
asynchronous=True, set_as_default=False) - f = await Client(s.address, asynchronous=True, set_as_default=False) + c = yield Client(s.address, asynchronous=True) + f = yield Client(s.address, asynchronous=True) bag = db.from_sequence([0, 1, 2]) bagp = c.persist(bag) @@ -159,19 +166,19 @@ async def test_publish_bag(s, a, b): keys = {f.key for f in futures_of(bagp)} assert keys == set(bag.dask) - await c.publish_dataset(data=bagp) + yield c.publish_dataset(data=bagp) # check that serialization didn't affect original bag's dask assert len(futures_of(bagp)) == 3 - result = await f.get_dataset("data") + result = yield f.get_dataset("data") assert set(result.dask.keys()) == set(bagp.dask.keys()) assert {f.key for f in result.dask.values()} == {f.key for f in bagp.dask.values()} - out = await f.compute(result) + out = yield f.compute(result) assert out == [0, 1, 2] - await c.close() - await f.close() + yield c.close() + yield f.close() def test_datasets_setitem(client): @@ -216,18 +223,19 @@ def test_datasets_iter(client): @gen_cluster(client=True) -async def test_pickle_safe(c, s, a, b): - async with Client( - s.address, asynchronous=True, serializers=["msgpack"], set_as_default=False - ) as c2: - await c2.publish_dataset(x=[1, 2, 3]) - result = await c2.get_dataset("x") +def test_pickle_safe(c, s, a, b): + c2 = yield Client(s.address, asynchronous=True, serializers=["msgpack"]) + try: + yield c2.publish_dataset(x=[1, 2, 3]) + result = yield c2.get_dataset("x") assert result == [1, 2, 3] with pytest.raises(TypeError): - await c2.publish_dataset(y=lambda x: x) + yield c2.publish_dataset(y=lambda x: x) - await c.publish_dataset(z=lambda x: x) # this can use pickle + yield c.publish_dataset(z=lambda x: x) # this can use pickle with pytest.raises(TypeError): - await c2.get_dataset("z") + yield c2.get_dataset("z") + finally: + yield c2.close() From 77a0d8aaf60fe4513b6c26f40c456ddbab86e882 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Thu, 16 Apr 2020 18:26:48 +0100 Subject: [PATCH 07/14] revert --- distributed/tests/test_client.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 7fe09d1d80..297394631a 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5966,16 +5966,3 @@ def test_as_completed_condition_loop(c, s, a, b): def test_client_connectionpool_semaphore_loop(s, a, b): with Client(s["address"]) as c: assert c.rpc.semaphore._loop is c.loop.asyncio_loop - - -@gen_cluster(client=False) -async def test_pickle_identity_issue3227(s, a, b): - """Ensure that, on a pickle round-trip, the default client is not reinitialised. - Specifically, make sure that hostname resolution won't cause get_client() to create - a new instance. 
- """ - for host in ("127.0.0.1", "localhost"): - url = host + ":" + s.address.split(":")[1] - async with Client(url, asynchronous=True) as c1: - c2 = pickle.loads(pickle.dumps(c1)) - assert c2 is c1 From 9d6c8c064c7ff5381229a5ca80ea3d5c4e262924 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 17 Apr 2020 08:58:21 +0100 Subject: [PATCH 08/14] revert --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index cf6732eaa7..86ee425adf 100644 --- a/.gitignore +++ b/.gitignore @@ -17,4 +17,3 @@ dask-worker-space/ *.swp .ycm_extra_conf.py tags -.ipynb_checkpoints From 21522df27ff5befcb67cdc07504a8f0adc0a6b9a Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 20 Apr 2020 10:50:51 +0100 Subject: [PATCH 09/14] xfail --- distributed/tests/test_steal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index fb5c96e14e..5ef3e5330e 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -106,6 +106,7 @@ async def test_worksteal_many_thieves(c, s, *workers): assert sum(map(len, s.has_what.values())) < 150 +@pytest.mark.xfail(reason="GH#3574") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2) async def test_dont_steal_unknown_functions(c, s, a, b): futures = c.map(inc, range(100), workers=a.address, allow_other_workers=True) From 4b55b310622cb22aabef29fc1b9ee6ccca07be53 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 20 Apr 2020 12:20:19 +0100 Subject: [PATCH 10/14] Better async functions --- distributed/publish.py | 51 +++++++++++++++++++++++++++---- distributed/tests/test_publish.py | 17 +++++++++++ 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/distributed/publish.py b/distributed/publish.py index 236c3a4707..4b30ebde04 100644 --- a/distributed/publish.py +++ b/distributed/publish.py @@ -59,21 +59,60 @@ class Datasets(MutableMapping): """ + __slots__ = ("_client",) + def __init__(self, client): - self.__client = client + self._client = client def __getitem__(self, key): - return self.__client.get_dataset(key) + # When client is asynchronous, it returns a coroutine + return self._client.get_dataset(key) def __setitem__(self, key, value): - self.__client.publish_dataset(value, name=key) + if self._client.asynchronous: + # 'await obj[key] = value' is not supported by Python as of 3.8 + raise TypeError( + "Can't use 'client.datasets[name] = value' when client is " + "asynchronous; please use 'client.publish_dataset(name=value)' instead" + ) + self._client.publish_dataset(value, name=key) def __delitem__(self, key): - self.__client.unpublish_dataset(key) + if self._client.asynchronous: + # 'await del obj[key]' is not supported by Python as of 3.8 + raise TypeError( + "Can't use 'del client.datasets[name]' when client is asynchronous; " + "please use 'client.unpublish_dataset(name)' instead" + ) + return self._client.unpublish_dataset(key) def __iter__(self): - for key in self.__client.list_datasets(): + if self._client.asynchronous: + raise TypeError( + "Can't invoke iter() or 'for' on client.datasets when client is " + "asynchronous; use 'async for' instead" + ) + for key in self._client.list_datasets(): yield key + def __aiter__(self): + if not self._client.asynchronous: + raise TypeError( + "Can't invoke 'async for' on client.datasets when client is " + "synchronous; use iter() or 'for' instead" + ) + + async def _(): + for key in await self._client.list_datasets(): + yield key + + return _() + def __len__(self): - return 
len(self.__client.list_datasets()) + if self._client.asynchronous: + # 'await len(obj)' is not supported by Python as of 3.8 + raise TypeError( + "Can't use 'len(client.datasets)' when client is asynchronous; " + "please use 'len(await client.list_datasets())' instead" + ) + return len(self._client.list_datasets()) diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index ab32d52a11..4284bbba5f 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -213,6 +213,23 @@ def test_datasets_iter(client): client.publish_dataset(**{str(key): key for key in keys}) for n, key in enumerate(client.datasets): assert key == str(n) + with pytest.raises(TypeError): + client.datasets.__aiter__() + + +@gen_cluster(client=True) +async def test_datasets_async(c, s, a, b): + await c.publish_dataset(foo=1, bar=2) + assert await c.datasets["foo"] == 1 + assert {k async for k in c.datasets} == {"foo", "bar"} + with pytest.raises(TypeError): + c.datasets["baz"] = 3 + with pytest.raises(TypeError): + del c.datasets["foo"] + with pytest.raises(TypeError): + next(iter(c.datasets)) + with pytest.raises(TypeError): + len(c.datasets) @gen_cluster(client=True) From fb8f7770204c2361adbc39cc365f8fba5eef6980 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 20 Apr 2020 16:01:45 +0100 Subject: [PATCH 11/14] Use contextvars to deserialize Future --- .github/workflows/ci-windows.yaml | 7 +++++ continuous_integration/travis/install.sh | 4 +++ distributed/client.py | 35 +++++++++++++++++++----- distributed/tests/test_publish.py | 31 +++++++++++++++++++++ requirements.txt | 1 + 5 files changed, 71 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci-windows.yaml b/.github/workflows/ci-windows.yaml index e0c95d0f23..4b73675a50 100644 --- a/.github/workflows/ci-windows.yaml +++ b/.github/workflows/ci-windows.yaml @@ -23,6 +23,13 @@ jobs: activate-environment: testenv auto-activate-base: false + - name: Install contextvars + shell: bash -l {0} + run: | + if [[ "${{ matrix.python-version }}" = "3.6" ]]; then + conda install -c conda-forge contextvars + fi + - name: Install tornado shell: bash -l {0} run: | diff --git a/continuous_integration/travis/install.sh b/continuous_integration/travis/install.sh index 4ee0790f6c..3420275d5a 100644 --- a/continuous_integration/travis/install.sh +++ b/continuous_integration/travis/install.sh @@ -55,6 +55,10 @@ conda install -c conda-forge -q \ zstandard \ $PACKAGES +if [[ $PYTHON == 3.6 ]]; then + conda install -c conda-forge -c defaults contextvars +fi + # stacktrace is not currently avaiable for Python 3.8. # Remove the version check block below when it is avaiable. 
if [[ $PYTHON != 3.8 ]]; then diff --git a/distributed/client.py b/distributed/client.py index 884ef7b614..11b2d95b1c 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from concurrent.futures._base import DoneAndNotDoneFutures from contextlib import contextmanager +from contextvars import ContextVar import copy import errno from functools import partial @@ -90,6 +91,7 @@ _global_clients = weakref.WeakValueDictionary() _global_client_index = [0] +_deserialize_client = ContextVar("_deserialize_client", default=None) DEFAULT_EXTENSIONS = [PubSubClientExtension] @@ -163,7 +165,7 @@ def __init__(self, key, client=None, inform=True, state=None): self.key = key self._cleared = False tkey = tokey(key) - self.client = client or _get_global_client() + self.client = client or _deserialize_client.get() or _get_global_client() self.client._inc_ref(tkey) self._generation = self.client.generation @@ -354,11 +356,11 @@ def release(self, _in_destructor=False): pass # Shutting down, add_callback may be None def __getstate__(self): - return (self.key, self.client.scheduler.address) + return self.key, self.client.scheduler.address def __setstate__(self, state): key, address = state - c = get_client(address) + c = _deserialize_client.get() or get_client(address) Future.__init__(self, key, c) c._send_to_scheduler( { @@ -2175,8 +2177,7 @@ def retry(self, futures, asynchronous=None): """ return self.sync(self._retry, futures, asynchronous=asynchronous) - @gen.coroutine - def _publish_dataset(self, *args, name=None, **kwargs): + async def _publish_dataset(self, *args, name=None, **kwargs): with log_errors(): coroutines = [] @@ -2202,7 +2203,7 @@ def add_coro(name, data): for name, data in kwargs.items(): add_coro(name, data) - yield coroutines + await asyncio.gather(*coroutines) def publish_dataset(self, *args, **kwargs): """ @@ -2282,7 +2283,27 @@ def list_datasets(self, **kwargs): return self.sync(self.scheduler.publish_list, **kwargs) async def _get_dataset(self, name): - out = await self.scheduler.publish_get(name=name, client=self.id) + if sys.version_info >= (3, 7): + # Insulate contextvars change with a task + async def _(): + _deserialize_client.set(self) + return await self.scheduler.publish_get(name=name, client=self.id) + + out = await asyncio.create_task(_()) + else: + # Python 3.6; creating a task doesn't copy the context. + # We can still detect a race condition though. + if _deserialize_client.get() not in (self, None): + raise RuntimeError( # pragma: nocover + "Detected race condition where get_dataset() is invoked in " + "parallel by multiple clients. Please upgrade to Python 3.7+." + ) + tok = _deserialize_client.set(self) + try: + out = await self.scheduler.publish_get(name=name, client=self.id) + finally: + _deserialize_client.reset(tok) + if out is None: raise KeyError(f"Dataset '{name}' not found") return out["data"] diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index 4284bbba5f..dfb7518d64 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -246,3 +246,34 @@ async def test_pickle_safe(c, s, a, b): with pytest.raises(TypeError): await c2.get_dataset("z") + + +@gen_cluster(client=True) +async def test_deserialize_client(c, s, a, b): + """Test that the client attached to Futures returned by Client.get_dataset is always + the instance of the client that invoked the method. 
+ Specifically: + + - when the client is defined by hostname, test that it is not accidentally + reinitialised by IP; + - when multiple clients are connected to the same scheduler, test that they don't + interfere with each other. + + See: https://github.com/dask/distributed/issues/3227 + """ + future = await c.scatter("123") + await c.publish_dataset(foo=future) + future = await c.get_dataset("foo") + assert future.client is c + + for addr in (s.address, "localhost:" + s.address.split(":")[-1]): + async with Client(addr, asynchronous=True) as c2: + future = await c.get_dataset("foo") + assert future.client is c + future = await c2.get_dataset("foo") + assert future.client is c2 + + # Ensure cleanup + from distributed.client import _deserialize_client + + assert _deserialize_client.get() is None diff --git a/requirements.txt b/requirements.txt index 4cb3ba60ae..b0d20cdb1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ click >= 6.6 cloudpickle >= 0.2.2 +contextvars;python_version<'3.7' dask >= 2.9.0 msgpack >= 0.6.0 psutil >= 5.0 From 042eddfff0aecdc5d61f7a18dc55ddb4c7b1034c Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 27 Apr 2020 13:18:04 +0100 Subject: [PATCH 12/14] Redesign --- distributed/client.py | 77 ++++++++++----- distributed/lock.py | 8 +- distributed/queues.py | 6 +- distributed/tests/test_client.py | 154 +++++++++++++++++++++++++----- distributed/tests/test_publish.py | 5 +- distributed/variable.py | 6 +- distributed/worker.py | 9 +- 7 files changed, 203 insertions(+), 62 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 45e75fea39..18df021b5c 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -90,7 +90,7 @@ _global_clients = weakref.WeakValueDictionary() _global_client_index = [0] -_deserialize_client = ContextVar("_deserialize_client", default=None) +_current_client = ContextVar("_current_client", default=None) DEFAULT_EXTENSIONS = [PubSubClientExtension] @@ -164,7 +164,7 @@ def __init__(self, key, client=None, inform=True, state=None): self.key = key self._cleared = False tkey = tokey(key) - self.client = client or _deserialize_client.get() or _get_global_client() + self.client = client or Client.current() self.client._inc_ref(tkey) self._generation = self.client.generation @@ -359,7 +359,10 @@ def __getstate__(self): def __setstate__(self, state): key, address = state - c = _deserialize_client.get() or get_client(address) + try: + c = Client.current(allow_global=False) + except ValueError: + c = get_client(address) Future.__init__(self, key, c) c._send_to_scheduler( { @@ -729,10 +732,43 @@ def __init__( ReplayExceptionClient(self) + @contextmanager + def as_current(self): + """Thread-local, Task-local context manager that causes the Client.current class + method to return self. This is used when a method of Client needs to propagate a + reference to self deep into the stack through generic methods that shouldn't be + aware of this class. + """ + # Python 3.6; contextvars are thread-local but not Task-local. + # We can still detect a race condition. + if sys.version_info < (3, 7) and _current_client.get() not in (self, None): + raise RuntimeError( # pragma: nocover + "Detected race condition where get_dataset() is invoked in " + "parallel by multiple asynchronous clients. " + "Please upgrade to Python 3.7+." 
+ ) + + tok = _current_client.set(self) + try: + yield + finally: + _current_client.reset(tok) + @classmethod - def current(cls): - """ Return global client if one exists, otherwise raise ValueError """ - return default_client() + def current(cls, allow_global=True): + """When running within the context of `as_client`, return the context-local + current client. Otherwise, return an arbitrary already existing Client. + If no Client instances exist, raise ValueError. + + If allow_global is set to False, raise ValueError if running outside of the + `as_client` context manager. + """ + out = _current_client.get() + if out: + return out + if allow_global: + return default_client() + raise ValueError("Not running inside the `as_current` context manager") @property def asynchronous(self): @@ -2286,26 +2322,8 @@ def list_datasets(self, **kwargs): return self.sync(self.scheduler.publish_list, **kwargs) async def _get_dataset(self, name): - if sys.version_info >= (3, 7): - # Insulate contextvars change with a task - async def _(): - _deserialize_client.set(self) - return await self.scheduler.publish_get(name=name, client=self.id) - - out = await asyncio.create_task(_()) - else: - # Python 3.6; creating a task doesn't copy the context. - # We can still detect a race condition though. - if _deserialize_client.get() not in (self, None): - raise RuntimeError( # pragma: nocover - "Detected race condition where get_dataset() is invoked in " - "parallel by multiple clients. Please upgrade to Python 3.7+." - ) - tok = _deserialize_client.set(self) - try: - out = await self.scheduler.publish_get(name=name, client=self.id) - finally: - _deserialize_client.reset(tok) + with self.as_current(): + out = await self.scheduler.publish_get(name=name, client=self.id) if out is None: raise KeyError(f"Dataset '{name}' not found") @@ -4715,6 +4733,13 @@ def __exit__(self, typ, value, traceback): def temp_default_client(c): """ Set the default client for the duration of the context + .. note:: + This function should be used for unit testing exclusively. In all other cases, + please use ``Client.as_current`` instead. + + .. note:: + Unlike Client.as_current, this function is not thread-local. + Parameters ---------- c : Client diff --git a/distributed/lock.py b/distributed/lock.py index 3c893a419c..7a55ccb441 100644 --- a/distributed/lock.py +++ b/distributed/lock.py @@ -3,7 +3,7 @@ import logging import uuid -from .client import _get_global_client +from .client import Client from .utils import log_errors, TimeoutError from .worker import get_worker @@ -93,7 +93,11 @@ class Lock: """ def __init__(self, name=None, client=None): - self.client = client or _get_global_client() or get_worker().client + try: + self.client = client or Client.current() + except ValueError: + # Initialise new client + self.client = get_worker().client self.name = name or "lock-" + uuid.uuid4().hex self.id = uuid.uuid4().hex self._locked = False diff --git a/distributed/queues.py b/distributed/queues.py index 81262703ad..324fb46c40 100644 --- a/distributed/queues.py +++ b/distributed/queues.py @@ -3,7 +3,7 @@ import logging import uuid -from .client import Future, _get_global_client, Client +from .client import Future, Client from .utils import tokey, sync, thread_state from .worker import get_client @@ -148,7 +148,7 @@ class Queue: not given, a random name will be generated. client: Client (optional) Client used for communication with the scheduler. Defaults to the - value of ``_get_global_client()``. + value of ``Client.current()``. 
maxsize: int (optional) Number of items allowed in the queue. If 0 (the default), the queue size is unbounded. @@ -167,7 +167,7 @@ class Queue: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or _get_global_client() + self.client = client or Client.current() self.name = name or "queue-" + uuid.uuid4().hex self._event_started = asyncio.Event() if self.client.asynchronous or getattr( diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index fd95895c84..64f47f4e42 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1201,7 +1201,7 @@ async def test_get_releases_data(c, s, a, b): assert time() < start + 2 -def test_Current(s, a, b): +def test_current(s, a, b): with Client(s["address"]) as c: assert Client.current() is c with pytest.raises(ValueError): @@ -3876,38 +3876,148 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b): @gen_cluster(client=False) async def test_serialize_future(s, a, b): - c = await Client(s.address, asynchronous=True) - f = await Client(s.address, asynchronous=True) + c1 = Client(s.address, asynchronous=True) + c2 = Client(s.address, asynchronous=True) - future = c.submit(lambda: 1) + future = c1.submit(lambda: 1) result = await future - with temp_default_client(f): - future2 = pickle.loads(pickle.dumps(future)) - assert future2.client is f - assert tokey(future2.key) in f.futures - result2 = await future2 - assert result == result2 + for ci in (c1, c2): + for ctxman in ci.as_current, lambda: temp_default_client(ci): + with ctxman(): + future2 = pickle.loads(pickle.dumps(future)) + assert future2.client is ci + assert tokey(future2.key) in ci.futures + result2 = await future2 + assert result == result2 - await c.close() - await f.close() + await c1.close() + await c2.close() @gen_cluster(client=False) -async def test_temp_client(s, a, b): - c = await Client(s.address, asynchronous=True) - f = await Client(s.address, asynchronous=True) +async def test_temp_default_client(s, a, b): + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) + + with temp_default_client(c1): + assert default_client() is c1 + assert default_client(c2) is c2 + + with temp_default_client(c2): + assert default_client() is c2 + assert default_client(c1) is c1 + + await c1.close() + await c2.close() + + +@gen_cluster(client=True) +async def test_as_current(c, s, a, b): + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) with temp_default_client(c): - assert default_client() is c - assert default_client(f) is f + assert Client.current() is c + with pytest.raises(ValueError): + Client.current(allow_global=False) + with c1.as_current(): + assert Client.current() is c1 + assert Client.current(allow_global=True) is c1 + with c2.as_current(): + assert Client.current() is c2 + assert Client.current(allow_global=True) is c2 - with temp_default_client(f): - assert default_client() is f - assert default_client(c) is c + await c1.close() + await c2.close() - await c.close() - await f.close() + +def test_as_current_is_thread_local(s): + l1 = threading.Lock() + l2 = threading.Lock() + l3 = threading.Lock() + l4 = threading.Lock() + l1.acquire() + l2.acquire() + l3.acquire() + l4.acquire() + + def run1(): + with Client(s.address) as c: + with c.as_current(): + l1.acquire() + l2.release() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert 
Client.current(allow_global=False) is c + finally: + l3.acquire() + l4.release() + + def run2(): + with Client(s.address) as c: + with c.as_current(): + l1.release() + l2.acquire() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert Client.current(allow_global=False) is c + finally: + l3.release() + l4.acquire() + + t1 = threading.Thread(target=run1) + t2 = threading.Thread(target=run2) + t1.start() + t2.start() + t1.join() + t2.join() + + +@pytest.mark.xfail( + sys.version_info < (3, 7), + reason="Python 3.6 contextvars are not copied on Task creation", +) +@gen_cluster(client=False) +async def test_as_current_is_task_local(s, a, b): + l1 = asyncio.Lock() + l2 = asyncio.Lock() + l3 = asyncio.Lock() + l4 = asyncio.Lock() + await l1.acquire() + await l2.acquire() + await l3.acquire() + await l4.acquire() + + async def run1(): + async with Client(s.address, asynchronous=True) as c: + with c.as_current(): + await l1.acquire() + l2.release() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert Client.current(allow_global=False) is c + finally: + await l3.acquire() + l4.release() + + async def run2(): + async with Client(s.address, asynchronous=True) as c: + with c.as_current(): + l1.release() + await l2.acquire() + try: + # This line runs only when both run1 and run2 are inside the + # context manager + assert Client.current(allow_global=False) is c + finally: + l3.release() + await l4.acquire() + + await asyncio.gather(run1(), run2()) @nodebug # test timing is fragile diff --git a/distributed/tests/test_publish.py b/distributed/tests/test_publish.py index dfb7518d64..a789f5a47f 100644 --- a/distributed/tests/test_publish.py +++ b/distributed/tests/test_publish.py @@ -259,6 +259,7 @@ async def test_deserialize_client(c, s, a, b): - when multiple clients are connected to the same scheduler, test that they don't interfere with each other. + See: test_client.test_serialize_future See: https://github.com/dask/distributed/issues/3227 """ future = await c.scatter("123") @@ -274,6 +275,6 @@ async def test_deserialize_client(c, s, a, b): assert future.client is c2 # Ensure cleanup - from distributed.client import _deserialize_client + from distributed.client import _current_client - assert _deserialize_client.get() is None + assert _current_client.get() is None diff --git a/distributed/variable.py b/distributed/variable.py index a47064b139..dc717533a2 100644 --- a/distributed/variable.py +++ b/distributed/variable.py @@ -5,7 +5,7 @@ from tlz import merge -from .client import Future, _get_global_client, Client +from .client import Future, Client from .utils import tokey, log_errors, TimeoutError, ignoring from .worker import get_client @@ -142,7 +142,7 @@ class Variable: If not given, a random name will be generated. client: Client (optional) Client used for communication with the scheduler. Defaults to the - value of ``_get_global_client()``. + value of ``Client.current()``. 
Examples -------- @@ -161,7 +161,7 @@ class Variable: """ def __init__(self, name=None, client=None, maxsize=0): - self.client = client or _get_global_client() + self.client = client or Client.current() self.name = name or "variable-" + uuid.uuid4().hex async def _set(self, value): diff --git a/distributed/worker.py b/distributed/worker.py index ef95c1f4b7..8dd043438f 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -3091,14 +3091,15 @@ def get_client(address=None, timeout=3, resolve_address=True): if not address or worker.scheduler.address == address: return worker._get_client(timeout=timeout) - from .client import _get_global_client + from .client import Client - client = _get_global_client() # TODO: assumes the same scheduler + try: + client = Client.current() # TODO: assumes the same scheduler + except ValueError: + client = None if client and (not address or client.scheduler.address == address): return client elif address: - from .client import Client - return Client(address, timeout=timeout) else: raise ValueError("No global client found and no address provided") From 621573db0d29a085eb821d733c24d9d316ca6992 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 27 Apr 2020 13:24:58 +0100 Subject: [PATCH 13/14] Tweaks --- distributed/client.py | 12 ++++++------ distributed/tests/test_client.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 18df021b5c..5ec11d9bd7 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -742,7 +742,7 @@ def as_current(self): # Python 3.6; contextvars are thread-local but not Task-local. # We can still detect a race condition. if sys.version_info < (3, 7) and _current_client.get() not in (self, None): - raise RuntimeError( # pragma: nocover + raise RuntimeError( "Detected race condition where get_dataset() is invoked in " "parallel by multiple asynchronous clients. " "Please upgrade to Python 3.7+." @@ -757,9 +757,8 @@ def as_current(self): @classmethod def current(cls, allow_global=True): """When running within the context of `as_client`, return the context-local - current client. Otherwise, return an arbitrary already existing Client. + current client. Otherwise, return the latest initialised Client. If no Client instances exist, raise ValueError. - If allow_global is set to False, raise ValueError if running outside of the `as_client` context manager. """ @@ -4734,11 +4733,12 @@ def temp_default_client(c): """ Set the default client for the duration of the context .. note:: - This function should be used for unit testing exclusively. In all other cases, - please use ``Client.as_current`` instead. + This function should be used exclusively for unit testing the default client + functionality. In all other cases, please use ``Client.as_current`` instead. .. note:: - Unlike Client.as_current, this function is not thread-local. + Unlike ``Client.as_current``, this context manager is neither thread-local nor + task-local. 
Parameters ---------- diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 64f47f4e42..634194bbae 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -3876,8 +3876,8 @@ async def test_scatter_compute_store_lose_processing(c, s, a, b): @gen_cluster(client=False) async def test_serialize_future(s, a, b): - c1 = Client(s.address, asynchronous=True) - c2 = Client(s.address, asynchronous=True) + c1 = await Client(s.address, asynchronous=True) + c2 = await Client(s.address, asynchronous=True) future = c1.submit(lambda: 1) result = await future From e70303a9fcc0857ae448cd0d6f9582831e7017cb Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 5 May 2020 11:49:46 +0100 Subject: [PATCH 14/14] docstrings --- distributed/client.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index 5ec11d9bd7..18d77ffab9 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -735,16 +735,15 @@ def __init__( @contextmanager def as_current(self): """Thread-local, Task-local context manager that causes the Client.current class - method to return self. This is used when a method of Client needs to propagate a - reference to self deep into the stack through generic methods that shouldn't be - aware of this class. + method to return self. Any Future objects deserialized inside this context + manager will be automatically attached to this Client. """ - # Python 3.6; contextvars are thread-local but not Task-local. - # We can still detect a race condition. + # In Python 3.6, contextvars are thread-local but not Task-local. + # We can still detect a race condition though. if sys.version_info < (3, 7) and _current_client.get() not in (self, None): raise RuntimeError( - "Detected race condition where get_dataset() is invoked in " - "parallel by multiple asynchronous clients. " + "Detected race condition where multiple asynchronous clients tried " + "entering the as_current() context manager at the same time. " "Please upgrade to Python 3.7+." )
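
A minimal usage sketch of the `Client.as_current()` / `Client.current()` API introduced by this series. It is not part of the patch itself: the scheduler address and the toy computation are placeholders, and the assertions mirror the behaviour exercised by `test_serialize_future` and `test_as_current` above.

import pickle

from distributed import Client

# Placeholder scheduler address; both clients point at the same scheduler.
c1 = Client("tcp://127.0.0.1:8786")
c2 = Client("tcp://127.0.0.1:8786")

future = c1.submit(lambda x: x + 1, 1)  # Future owned by c1
payload = pickle.dumps(future)

with c2.as_current():
    # Context-local lookup; no fallback to the global default client.
    assert Client.current(allow_global=False) is c2
    # Futures deserialized inside the block attach to c2, not c1.
    restored = pickle.loads(payload)
    assert restored.client is c2
    assert restored.result() == 2

# Outside the block, allow_global=False would raise ValueError, while the
# default lookup falls back to the latest initialised client (here c2).
assert Client.current() is c2

c1.close()
c2.close()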