Skip to content

Commit

Permalink
Accept from_dtype(**kw)
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Sep 23, 2020
1 parent b1b9a0c commit 3bf95ed
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 23 deletions.
11 changes: 11 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
RELEASE_TYPE: minor

This release upgrades the :func:`~hypothesis.extra.numpy.from_dtype` strategy
to pass optional ``**kwargs`` to the inferred strategy, and upgrades the
:func:`~hypothesis.extra.numpy.arrays` strategy to accept an ``elements=kwargs``
dict to pass through to :func:`~hypothesis.extra.numpy.from_dtype`.

``arrays(floating_dtypes(), shape, elements={"min_value": -10, "max_value": 10})``
is a particularly useful pattern, as it allows for any floating dtype without
triggering the roundoff warning for smaller types or sacrificing variety for
larger types (:issue:`2552`).
91 changes: 69 additions & 22 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import math
import re
from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Any, Mapping, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np

Expand All @@ -41,49 +41,97 @@ class BroadcastableShapes(NamedTuple):


@st.defines_strategy(force_reusable_values=True)
def from_dtype(dtype: np.dtype) -> st.SearchStrategy[Any]:
"""Creates a strategy which can generate any value of the given dtype."""
def from_dtype(
dtype: np.dtype,
*,
alphabet: Optional[st.SearchStrategy[str]] = None,
min_size: int = 0,
max_size: Optional[int] = None,
min_value: Union[int, float, None] = None,
max_value: Union[int, float, None] = None,
allow_nan: Optional[bool] = None,
allow_infinity: Optional[bool] = None,
exclude_min: Optional[bool] = None,
exclude_max: Optional[bool] = None,
) -> st.SearchStrategy[Any]:
"""Creates a strategy which can generate any value of the given dtype.
Compatible ``**kwargs`` are passed to the inferred strategy function for
integers, floats, and strings. This allows you to customise the min and max
values, control the length or contents of strings, or exclude non-finite
numbers. This is particularly useful when kwargs are passed through from
:func:`arrays` which allow a variety of numeric dtypes, as it seamlessly
handles the ``width`` or representable bounds for you. See :issue:`2552`
for more detail.
"""
check_type(np.dtype, dtype, "dtype")
kwargs = {k: v for k, v in locals().items() if k != "dtype" and v is not None}

# Compound datatypes, eg 'f4,f4,f4'
if dtype.names is not None:
# mapping np.void.type over a strategy is nonsense, so return now.
return st.tuples(*[from_dtype(dtype.fields[name][0]) for name in dtype.names])
subs = [from_dtype(dtype.fields[name][0], **kwargs) for name in dtype.names]
return st.tuples(*subs)

# Subarray datatypes, eg '(2, 3)i4'
if dtype.subdtype is not None:
subtype, shape = dtype.subdtype
return arrays(subtype, shape)
return arrays(subtype, shape, elements=kwargs)

def compat_kw(*args, **kw):
"""Update default args to the strategy with user-supplied keyword args."""
assert {"min_value", "max_value", "max_size"}.issuperset(kw)
for key in set(kwargs).intersection(kw):
msg = f"dtype {dtype!r} requires {key}={kwargs[key]!r} to be %s {kw[key]!r}"
if kw[key] is not None:
if key.startswith("min_") and kw[key] > kwargs[key]:
raise InvalidArgument(msg % ("at least",))
elif key.startswith("max_") and kw[key] < kwargs[key]:
raise InvalidArgument(msg % ("at most",))
kw.update({k: v for k, v in kwargs.items() if k in args or k in kw})
return kw

# Scalar datatypes
if dtype.kind == "b":
result = st.booleans() # type: SearchStrategy[Any]
elif dtype.kind == "f":
if dtype.itemsize == 2:
result = st.floats(width=16)
elif dtype.itemsize == 4:
result = st.floats(width=32)
else:
result = st.floats()
result = st.floats(
width=8 * dtype.itemsize,
**compat_kw(
"min_value",
"max_value",
"allow_nan",
"allow_infinity",
"exclude_min",
"exclude_max",
),
)
elif dtype.kind == "c":
# If anyone wants to add a `width` argument to `complex_numbers()`, we would
# accept a pull request and add passthrough support for magnitude bounds,
# but it's a low priority otherwise.
if dtype.itemsize == 8:
float32 = st.floats(width=32)
float32 = st.floats(width=32, **compat_kw("allow_nan", "allow_infinity"))
result = st.builds(complex, float32, float32)
else:
result = st.complex_numbers()
result = st.complex_numbers(**compat_kw("allow_nan", "allow_infinity"))
elif dtype.kind in ("S", "a"):
# Numpy strings are null-terminated; only allow round-trippable values.
# `itemsize == 0` means 'fixed length determined at array creation'
result = st.binary(max_size=dtype.itemsize or None).filter(
max_size = dtype.itemsize or None
result = st.binary(**compat_kw("min_size", max_size=max_size)).filter(
lambda b: b[-1:] != b"\0"
)
elif dtype.kind == "u":
result = st.integers(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1)
kw = compat_kw(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1)
result = st.integers(**kw)
elif dtype.kind == "i":
overflow = 2 ** (8 * dtype.itemsize - 1)
result = st.integers(min_value=-overflow, max_value=overflow - 1)
result = st.integers(**compat_kw(min_value=-overflow, max_value=overflow - 1))
elif dtype.kind == "U":
# Encoded in UTF-32 (four bytes/codepoint) and null-terminated
result = st.text(max_size=(dtype.itemsize or 0) // 4 or None).filter(
max_size = (dtype.itemsize or 0) // 4 or None
result = st.text(**compat_kw("alphabet", "min_size", max_size=max_size)).filter(
lambda b: b[-1:] != "\0"
)
elif dtype.kind in ("m", "M"):
Expand Down Expand Up @@ -303,7 +351,7 @@ def arrays(
dtype: Any,
shape: Union[int, Shape, st.SearchStrategy[Shape]],
*,
elements: Optional[st.SearchStrategy[Any]] = None,
elements: Optional[Union[SearchStrategy, Mapping[str, Any]]] = None,
fill: Optional[st.SearchStrategy[Any]] = None,
unique: bool = False,
) -> st.SearchStrategy[np.ndarray]:
Expand All @@ -317,8 +365,7 @@ def arrays(
* ``elements`` is a strategy for generating values to put in the array.
If it is None a suitable value will be inferred based on the dtype,
which may give any legal value (including eg ``NaN`` for floats).
If you have more specific requirements, you should supply your own
elements strategy.
If a mapping, it will be passed as ``**kwargs`` to ``from_dtype()``
* ``fill`` is a strategy that may be used to generate a single background
value for the array. If None, a suitable default will be inferred
based on the other arguments. If set to
Expand Down Expand Up @@ -391,7 +438,7 @@ def arrays(
)
# From here on, we're only dealing with values and it's relatively simple.
dtype = np.dtype(dtype)
if elements is None:
if elements is None or isinstance(elements, Mapping):
if dtype.kind in ("m", "M") and "[" not in dtype.str:
# For datetime and timedelta dtypes, we have a tricky situation -
# because they *may or may not* specify a unit as part of the dtype.
Expand All @@ -402,7 +449,7 @@ def arrays(
.map((dtype.str + "[{}]").format)
.flatmap(lambda d: arrays(d, shape=shape, fill=fill, unique=unique))
)
elements = from_dtype(dtype)
elements = from_dtype(dtype, **(elements or {}))
check_strategy(elements, "elements")
if isinstance(shape, int):
shape = (shape,)
Expand Down
6 changes: 6 additions & 0 deletions hypothesis-python/tests/numpy/test_argument_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def e(a, **kwargs):
e(nps.from_dtype, dtype=float),
e(nps.from_dtype, dtype=numpy.int8),
e(nps.from_dtype, dtype=1),
e(nps.from_dtype, dtype=numpy.dtype("uint8"), min_value=-999),
e(nps.from_dtype, dtype=numpy.dtype("uint8"), max_value=999),
e(nps.from_dtype, dtype=numpy.dtype("int8"), min_value=-999),
e(nps.from_dtype, dtype=numpy.dtype("int8"), max_value=999),
e(nps.from_dtype, dtype=numpy.dtype("S4"), max_size=5),
e(nps.from_dtype, dtype=numpy.dtype("U4"), max_size=5),
e(nps.valid_tuple_axes, ndim=-1),
e(nps.valid_tuple_axes, ndim=2, min_size=-1),
e(nps.valid_tuple_axes, ndim=2, min_size=3, max_size=10),
Expand Down
43 changes: 42 additions & 1 deletion hypothesis-python/tests/numpy/test_from_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

import numpy as np
import pytest
from tests.common.debug import find_any

from hypothesis import assume, given, settings, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.extra import numpy as nps
from hypothesis.strategies._internal import SearchStrategy
from tests.common.debug import find_any

STANDARD_TYPES = [
np.dtype(t)
Expand Down Expand Up @@ -180,3 +180,44 @@ def test_arrays_selects_consistent_time_unit(data, dtype_str):
def test_arrays_gives_useful_error_on_inconsistent_time_unit():
with pytest.raises(InvalidArgument, match="mismatch of time units in dtypes"):
nps.arrays("m8[Y]", 10, elements=nps.from_dtype(np.dtype("m8[D]"))).example()


@pytest.mark.parametrize(
"dtype, kwargs, pred",
[
# Floating point: bounds, exclusive bounds, and excluding nonfinites
(float, {"min_value": 1, "max_value": 2}, lambda x: 1 <= x <= 2),
(
float,
{"min_value": 1, "max_value": 2, "exclude_min": True, "exclude_max": True},
lambda x: 1 < x < 2,
),
(float, {"allow_nan": False}, lambda x: not np.isnan(x)),
(float, {"allow_infinity": False}, lambda x: not np.isinf(x)),
(float, {"allow_nan": False, "allow_infinity": False}, np.isfinite),
(complex, {"allow_nan": False}, lambda x: not np.isnan(x)),
(complex, {"allow_infinity": False}, lambda x: not np.isinf(x)),
(complex, {"allow_nan": False, "allow_infinity": False}, np.isfinite),
# Integer bounds, limited to the representable range
("int8", {"min_value": -1, "max_value": 1}, lambda x: -1 <= x <= 1),
("uint8", {"min_value": 1, "max_value": 2}, lambda x: 1 <= x <= 2),
# String arguments, bounding size and unicode alphabet
("S", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2),
("S4", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2),
("U", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2),
("U4", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2),
("U", {"alphabet": "abc"}, lambda x: set(x).issubset("abc")),
],
)
@given(data=st.data())
def test_from_dtype_with_kwargs(data, dtype, kwargs, pred):
value = data.draw(nps.from_dtype(np.dtype(dtype), **kwargs))
assert pred(value)


@given(nps.from_dtype(np.dtype("U20,uint8,float32"), min_size=1, allow_nan=False))
def test_customize_structured_dtypes(x):
name, age, score = x
assert len(name) >= 1
assert 0 <= age <= 255
assert not np.isnan(score)

0 comments on commit 3bf95ed

Please sign in to comment.