From 15905fd44c67e69adad1f9b36db9a264bc9a350b Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 19 Aug 2022 10:47:38 +0100 Subject: [PATCH] cleanup ipywidgets mocking --- distributed/diagnostics/tests/test_widgets.py | 89 +++++++++---------- 1 file changed, 44 insertions(+), 45 deletions(-) diff --git a/distributed/diagnostics/tests/test_widgets.py b/distributed/diagnostics/tests/test_widgets.py index d8d8e9fa6b..34dff54c09 100644 --- a/distributed/diagnostics/tests/test_widgets.py +++ b/distributed/diagnostics/tests/test_widgets.py @@ -1,13 +1,29 @@ from __future__ import annotations +import contextlib +import re +from operator import add +from unittest import mock + import pytest from packaging.version import parse as parse_version +from tlz import valmap + +from distributed.client import wait +from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws +from distributed.worker import dumps_task ipywidgets = pytest.importorskip("ipywidgets") from ipykernel.comm import Comm from ipywidgets import Widget +from distributed.diagnostics.progressbar import ( + MultiProgressWidget, + ProgressWidget, + progress, +) + ################# # Utility stuff # ################# @@ -31,68 +47,42 @@ def close(self, *args, **kwargs): pass -_widget_attrs = {} -displayed = [] -undefined = object() - - -def setup(): - _widget_attrs["_comm_default"] = getattr(Widget, "_comm_default", undefined) - Widget._comm_default = lambda self: DummyComm() - if parse_version(ipywidgets.__version__) >= parse_version("8.0.0"): - display_attr = "_repr_mimebundle_" - else: - display_attr = "_ipython_display_" - _widget_attrs[display_attr] = getattr(Widget, display_attr) +_DISPLAY_ATTR = ( + "_repr_mimebundle_" + if parse_version(ipywidgets.__version__) >= parse_version("8.0.0") + else "_ipython_display_" +) - def raise_not_implemented(*args, **kwargs): - raise NotImplementedError() - Widget._ipython_display_ = raise_not_implemented +def _comm_default(self): + return DummyComm() -def teardown(): - for attr, value in _widget_attrs.items(): - if value is undefined: - delattr(Widget, attr) - else: - setattr(Widget, attr, value) +@contextlib.contextmanager +def mock_widget(): + with mock.patch( + f"ipywidgets.Widget.{_DISPLAY_ATTR}", side_effect=NotImplementedError + ): + assert not hasattr(Widget, "_comm_default") + Widget._comm_default = _comm_default + try: + yield + finally: + del Widget._comm_default def f(**kwargs): pass -def clear_display(): - global displayed - displayed = [] - - -def record_display(*args): - displayed.extend(args) - - # End code taken from ipywidgets ##################### # Distributed stuff # ##################### -import re -from operator import add - -from tlz import valmap - -from distributed.client import wait -from distributed.diagnostics.progressbar import ( - MultiProgressWidget, - ProgressWidget, - progress, -) -from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws -from distributed.worker import dumps_task - +@mock_widget() @gen_cluster(client=True) async def test_progressbar_widget(c, s, a, b): x = c.submit(inc, 1) @@ -110,6 +100,7 @@ async def test_progressbar_widget(c, s, a, b): await progress.listen() +@mock_widget() @gen_cluster(client=True) async def test_multi_progressbar_widget(c, s, a, b): x1 = c.submit(inc, 1) @@ -150,6 +141,7 @@ async def test_multi_progressbar_widget(c, s, a, b): assert sorted(capacities, reverse=True) == capacities +@mock_widget() @gen_cluster() async def test_multi_progressbar_widget_after_close(s, a, b): s.update_graph( @@ -181,6 +173,7 @@ async def test_multi_progressbar_widget_after_close(s, a, b): assert "x" in p.bars +@mock_widget() def test_values(client): L = [client.submit(inc, i) for i in range(5)] wait(L) @@ -198,6 +191,7 @@ def test_values(client): assert p.status == "error" +@mock_widget() def test_progressbar_done(client): L = [client.submit(inc, i) for i in range(5)] wait(L) @@ -224,6 +218,7 @@ def test_progressbar_done(client): assert repr(e) in p.elapsed_time.value +@mock_widget() def test_progressbar_cancel(client): import time @@ -236,6 +231,7 @@ def test_progressbar_cancel(client): assert p.bar.value == 0 # no tasks finish before cancel is called +@mock_widget() @gen_cluster() async def test_multibar_complete(s, a, b): s.update_graph( @@ -270,6 +266,7 @@ async def test_multibar_complete(s, a, b): assert "2 / 2" in p.bar_texts["y"].value +@mock_widget() def test_fast(client): L = client.map(inc, range(100)) L2 = client.map(dec, L) @@ -279,6 +276,7 @@ def test_fast(client): assert set(p._last_response["all"]) == {"inc", "dec", "add"} +@mock_widget() @gen_cluster(client=True, client_kwargs={"serializers": ["msgpack"]}) async def test_serializers(c, s, a, b): x = c.submit(inc, 1) @@ -293,6 +291,7 @@ async def test_serializers(c, s, a, b): assert "3 / 3" in progress.bar_text.value +@mock_widget() @gen_tls_cluster(client=True) async def test_tls(c, s, a, b): x = c.submit(inc, 1)