Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cleanup ipywidgets mocking #6918

Merged
merged 1 commit into from
Aug 31, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 44 additions & 45 deletions distributed/diagnostics/tests/test_widgets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
from __future__ import annotations

import contextlib
import re
from operator import add
from unittest import mock

import pytest
from packaging.version import parse as parse_version
from tlz import valmap

from distributed.client import wait
from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws
from distributed.worker import dumps_task

ipywidgets = pytest.importorskip("ipywidgets")

from ipykernel.comm import Comm
from ipywidgets import Widget

from distributed.diagnostics.progressbar import (
MultiProgressWidget,
ProgressWidget,
progress,
)

#################
# Utility stuff #
#################
Expand All @@ -31,68 +47,42 @@ def close(self, *args, **kwargs):
pass


_widget_attrs = {}
displayed = []
undefined = object()


def setup():
_widget_attrs["_comm_default"] = getattr(Widget, "_comm_default", undefined)
Widget._comm_default = lambda self: DummyComm()
if parse_version(ipywidgets.__version__) >= parse_version("8.0.0"):
display_attr = "_repr_mimebundle_"
else:
display_attr = "_ipython_display_"
_widget_attrs[display_attr] = getattr(Widget, display_attr)
_DISPLAY_ATTR = (
"_repr_mimebundle_"
if parse_version(ipywidgets.__version__) >= parse_version("8.0.0")
else "_ipython_display_"
)

def raise_not_implemented(*args, **kwargs):
raise NotImplementedError()

Widget._ipython_display_ = raise_not_implemented
def _comm_default(self):
return DummyComm()


def teardown():
for attr, value in _widget_attrs.items():
if value is undefined:
delattr(Widget, attr)
else:
setattr(Widget, attr, value)
@contextlib.contextmanager
def mock_widget():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might want to upstream this

with mock.patch(
f"ipywidgets.Widget.{_DISPLAY_ATTR}", side_effect=NotImplementedError
):
assert not hasattr(Widget, "_comm_default")
Widget._comm_default = _comm_default
try:
yield
finally:
del Widget._comm_default


def f(**kwargs):
pass


def clear_display():
global displayed
displayed = []


def record_display(*args):
displayed.extend(args)


# End code taken from ipywidgets

#####################
# Distributed stuff #
#####################

import re
from operator import add

from tlz import valmap

from distributed.client import wait
from distributed.diagnostics.progressbar import (
MultiProgressWidget,
ProgressWidget,
progress,
)
from distributed.utils_test import dec, gen_cluster, gen_tls_cluster, inc, throws
from distributed.worker import dumps_task


@mock_widget()
@gen_cluster(client=True)
async def test_progressbar_widget(c, s, a, b):
x = c.submit(inc, 1)
Expand All @@ -110,6 +100,7 @@ async def test_progressbar_widget(c, s, a, b):
await progress.listen()


@mock_widget()
@gen_cluster(client=True)
async def test_multi_progressbar_widget(c, s, a, b):
x1 = c.submit(inc, 1)
Expand Down Expand Up @@ -150,6 +141,7 @@ async def test_multi_progressbar_widget(c, s, a, b):
assert sorted(capacities, reverse=True) == capacities


@mock_widget()
@gen_cluster()
async def test_multi_progressbar_widget_after_close(s, a, b):
s.update_graph(
Expand Down Expand Up @@ -181,6 +173,7 @@ async def test_multi_progressbar_widget_after_close(s, a, b):
assert "x" in p.bars


@mock_widget()
def test_values(client):
L = [client.submit(inc, i) for i in range(5)]
wait(L)
Expand All @@ -198,6 +191,7 @@ def test_values(client):
assert p.status == "error"


@mock_widget()
def test_progressbar_done(client):
L = [client.submit(inc, i) for i in range(5)]
wait(L)
Expand All @@ -224,6 +218,7 @@ def test_progressbar_done(client):
assert repr(e) in p.elapsed_time.value


@mock_widget()
def test_progressbar_cancel(client):
import time

Expand All @@ -236,6 +231,7 @@ def test_progressbar_cancel(client):
assert p.bar.value == 0 # no tasks finish before cancel is called


@mock_widget()
@gen_cluster()
async def test_multibar_complete(s, a, b):
s.update_graph(
Expand Down Expand Up @@ -270,6 +266,7 @@ async def test_multibar_complete(s, a, b):
assert "2 / 2" in p.bar_texts["y"].value


@mock_widget()
def test_fast(client):
L = client.map(inc, range(100))
L2 = client.map(dec, L)
Expand All @@ -279,6 +276,7 @@ def test_fast(client):
assert set(p._last_response["all"]) == {"inc", "dec", "add"}


@mock_widget()
@gen_cluster(client=True, client_kwargs={"serializers": ["msgpack"]})
async def test_serializers(c, s, a, b):
x = c.submit(inc, 1)
Expand All @@ -293,6 +291,7 @@ async def test_serializers(c, s, a, b):
assert "3 / 3" in progress.bar_text.value


@mock_widget()
@gen_tls_cluster(client=True)
async def test_tls(c, s, a, b):
x = c.submit(inc, 1)
Expand Down