Skip to content

Commit

Permalink
test_serialize: if numpy is not present, only skip numpy tests, not all.
Browse files Browse the repository at this point in the history
  • Loading branch information
bnavigator committed Mar 18, 2021
1 parent 18d7ffa commit c0552f9
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions distributed/protocol/tests/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import pytest
from tlz import identity

np = pytest.importorskip("numpy")
try:
import numpy as np
except ImportError:
np = None

from dask.utils_test import inc

Expand Down Expand Up @@ -213,6 +216,7 @@ def test_empty_loads_deep():
assert isinstance(e2[0][0][0], Empty)


@pytest.mark.skipif(np is None, reason="Test needs numpy")
@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
def test_serialize_bytes(kwargs):
for x in [
Expand All @@ -229,6 +233,7 @@ def test_serialize_bytes(kwargs):
assert str(x) == str(y)


@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_serialize_list_compress():
pytest.importorskip("lz4")
x = np.ones(1000000)
Expand Down Expand Up @@ -440,7 +445,11 @@ def _(x):
(tuple([MyObj(None)]), True),
({("x", i): MyObj(5) for i in range(100)}, True),
(memoryview(b"hello"), True),
(memoryview(np.random.random((3, 4))), True),
pytest.param(
memoryview(np.random.random((3, 4))
if np is not None else b"skip np.random"),
True,
marks=pytest.mark.skipif(np is None, reason="Test needs numpy")),
],
)
def test_check_dask_serializable(data, is_serializable):
Expand All @@ -463,7 +472,14 @@ def test_serialize_lists(serializers):


@pytest.mark.parametrize(
"data_in", [memoryview(b"hello"), memoryview(np.random.random((3, 4)))]
"data_in",
[
memoryview(b"hello"),
pytest.param(
memoryview(np.random.random((3, 4))
if np is not None else b"skip np.random"),
marks=pytest.mark.skipif(np is None, reason="Test needs numpy"))
],
)
def test_deser_memoryview(data_in):
header, frames = serialize(data_in)
Expand All @@ -473,6 +489,7 @@ def test_deser_memoryview(data_in):
assert data_in == data_out


@pytest.mark.skipif(np is None, reason="Test needs numpy")
def test_ser_memoryview_object():
data_in = memoryview(np.array(["hello"], dtype=object))
with pytest.raises(TypeError):
Expand Down

0 comments on commit c0552f9

Please sign in to comment.