Skip to content

Commit

Permalink
cleanup ipywidgets mocking (dask#6918)
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert authored and gjoseph92 committed Oct 31, 2022
1 parent 4c9bb19 commit 181a3c3
Showing 1 changed file with 44 additions and 45 deletions.
89 changes: 44 additions & 45 deletions distributed/diagnostics/tests/test_widgets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
from __future__ import annotations

import contextlib
import re
from operator import add
from unittest import mock

import pytest
from packaging.version import parse as parse_version
from tlz import valmap

from distributed.client import wait
from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws
from distributed.worker import dumps_task

ipywidgets = pytest.importorskip("ipywidgets")

from ipykernel.comm import Comm
from ipywidgets import Widget

from distributed.diagnostics.progressbar import (
MultiProgressWidget,
ProgressWidget,
progress,
)

#################
# Utility stuff #
#################
Expand All @@ -31,68 +47,42 @@ def close(self, *args, **kwargs):
pass


_widget_attrs = {}
displayed = []
undefined = object()


def setup():
_widget_attrs["_comm_default"] = getattr(Widget, "_comm_default", undefined)
Widget._comm_default = lambda self: DummyComm()
if parse_version(ipywidgets.__version__) >= parse_version("8.0.0"):
display_attr = "_repr_mimebundle_"
else:
display_attr = "_ipython_display_"
_widget_attrs[display_attr] = getattr(Widget, display_attr)
_DISPLAY_ATTR = (
"_repr_mimebundle_"
if parse_version(ipywidgets.__version__) >= parse_version("8.0.0")
else "_ipython_display_"
)

def raise_not_implemented(*args, **kwargs):
raise NotImplementedError()

Widget._ipython_display_ = raise_not_implemented
def _comm_default(self):
return DummyComm()


def teardown():
for attr, value in _widget_attrs.items():
if value is undefined:
delattr(Widget, attr)
else:
setattr(Widget, attr, value)
@contextlib.contextmanager
def mock_widget():
with mock.patch(
f"ipywidgets.Widget.{_DISPLAY_ATTR}", side_effect=NotImplementedError
):
assert not hasattr(Widget, "_comm_default")
Widget._comm_default = _comm_default
try:
yield
finally:
del Widget._comm_default


def f(**kwargs):
pass


def clear_display():
global displayed
displayed = []


def record_display(*args):
displayed.extend(args)


# End code taken from ipywidgets

#####################
# Distributed stuff #
#####################

import re
from operator import add

from tlz import valmap

from distributed.client import wait
from distributed.diagnostics.progressbar import (
MultiProgressWidget,
ProgressWidget,
progress,
)
from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws
from distributed.worker import dumps_task


@mock_widget()
@gen_cluster(client=True)
async def test_progressbar_widget(c, s, a, b):
x = c.submit(inc, 1)
Expand All @@ -110,6 +100,7 @@ async def test_progressbar_widget(c, s, a, b):
await progress.listen()


@mock_widget()
@gen_cluster(client=True)
async def test_multi_progressbar_widget(c, s, a, b):
x1 = c.submit(inc, 1)
Expand Down Expand Up @@ -150,6 +141,7 @@ async def test_multi_progressbar_widget(c, s, a, b):
assert sorted(capacities, reverse=True) == capacities


@mock_widget()
@gen_cluster()
async def test_multi_progressbar_widget_after_close(s, a, b):
s.update_graph(
Expand Down Expand Up @@ -181,6 +173,7 @@ async def test_multi_progressbar_widget_after_close(s, a, b):
assert "x" in p.bars


@mock_widget()
def test_values(client):
L = [client.submit(inc, i) for i in range(5)]
wait(L)
Expand All @@ -198,6 +191,7 @@ def test_values(client):
assert p.status == "error"


@mock_widget()
def test_progressbar_done(client):
L = [client.submit(inc, i) for i in range(5)]
wait(L)
Expand All @@ -224,6 +218,7 @@ def test_progressbar_done(client):
assert repr(e) in p.elapsed_time.value


@mock_widget()
def test_progressbar_cancel(client):
import time

Expand All @@ -236,6 +231,7 @@ def test_progressbar_cancel(client):
assert p.bar.value == 0 # no tasks finish before cancel is called


@mock_widget()
@gen_cluster()
async def test_multibar_complete(s, a, b):
s.update_graph(
Expand Down Expand Up @@ -270,6 +266,7 @@ async def test_multibar_complete(s, a, b):
assert "2 / 2" in p.bar_texts["y"].value


@mock_widget()
def test_fast(client):
L = client.map(inc, range(100))
L2 = client.map(dec, L)
Expand All @@ -279,6 +276,7 @@ def test_fast(client):
assert set(p._last_response["all"]) == {"inc", "dec", "add"}


@mock_widget()
@gen_cluster(client=True, client_kwargs={"serializers": ["msgpack"]})
async def test_serializers(c, s, a, b):
x = c.submit(inc, 1)
Expand All @@ -293,6 +291,7 @@ async def test_serializers(c, s, a, b):
assert "3 / 3" in progress.bar_text.value


@mock_widget()
@gen_tls_cluster(client=True)
async def test_tls(c, s, a, b):
x = c.submit(inc, 1)
Expand Down

0 comments on commit 181a3c3

Please sign in to comment.