Skip to content

Commit

Permalink
Merge pull request #2619 from Zac-HD/dtype-passthrough
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD authored Sep 24, 2020
2 parents 51622ee + 8f2c755 commit bb02569
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 183 deletions.
11 changes: 11 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
RELEASE_TYPE: minor

This release upgrades the :func:`~hypothesis.extra.numpy.from_dtype` strategy
to pass optional ``**kwargs`` to the inferred strategy, and upgrades the
:func:`~hypothesis.extra.numpy.arrays` strategy to accept an ``elements=kwargs``
dict to pass through to :func:`~hypothesis.extra.numpy.from_dtype`.

``arrays(floating_dtypes(), shape, elements={"min_value": -10, "max_value": 10})``
is a particularly useful pattern, as it allows for any floating dtype without
triggering the roundoff warning for smaller types or sacrificing variety for
larger types (:issue:`2552`).
91 changes: 69 additions & 22 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import math
import re
from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Any, Mapping, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np

Expand All @@ -41,49 +41,97 @@ class BroadcastableShapes(NamedTuple):


@st.defines_strategy(force_reusable_values=True)
def from_dtype(dtype: np.dtype) -> st.SearchStrategy[Any]:
"""Creates a strategy which can generate any value of the given dtype."""
def from_dtype(
dtype: np.dtype,
*,
alphabet: Optional[st.SearchStrategy[str]] = None,
min_size: int = 0,
max_size: Optional[int] = None,
min_value: Union[int, float, None] = None,
max_value: Union[int, float, None] = None,
allow_nan: Optional[bool] = None,
allow_infinity: Optional[bool] = None,
exclude_min: Optional[bool] = None,
exclude_max: Optional[bool] = None,
) -> st.SearchStrategy[Any]:
"""Creates a strategy which can generate any value of the given dtype.
Compatible ``**kwargs`` are passed to the inferred strategy function for
integers, floats, and strings. This allows you to customise the min and max
values, control the length or contents of strings, or exclude non-finite
numbers. This is particularly useful when kwargs are passed through from
:func:`arrays` which allow a variety of numeric dtypes, as it seamlessly
handles the ``width`` or representable bounds for you. See :issue:`2552`
for more detail.
"""
check_type(np.dtype, dtype, "dtype")
kwargs = {k: v for k, v in locals().items() if k != "dtype" and v is not None}

# Compound datatypes, eg 'f4,f4,f4'
if dtype.names is not None:
# mapping np.void.type over a strategy is nonsense, so return now.
return st.tuples(*[from_dtype(dtype.fields[name][0]) for name in dtype.names])
subs = [from_dtype(dtype.fields[name][0], **kwargs) for name in dtype.names]
return st.tuples(*subs)

# Subarray datatypes, eg '(2, 3)i4'
if dtype.subdtype is not None:
subtype, shape = dtype.subdtype
return arrays(subtype, shape)
return arrays(subtype, shape, elements=kwargs)

def compat_kw(*args, **kw):
"""Update default args to the strategy with user-supplied keyword args."""
assert {"min_value", "max_value", "max_size"}.issuperset(kw)
for key in set(kwargs).intersection(kw):
msg = f"dtype {dtype!r} requires {key}={kwargs[key]!r} to be %s {kw[key]!r}"
if kw[key] is not None:
if key.startswith("min_") and kw[key] > kwargs[key]:
raise InvalidArgument(msg % ("at least",))
elif key.startswith("max_") and kw[key] < kwargs[key]:
raise InvalidArgument(msg % ("at most",))
kw.update({k: v for k, v in kwargs.items() if k in args or k in kw})
return kw

# Scalar datatypes
if dtype.kind == "b":
result = st.booleans() # type: SearchStrategy[Any]
elif dtype.kind == "f":
if dtype.itemsize == 2:
result = st.floats(width=16)
elif dtype.itemsize == 4:
result = st.floats(width=32)
else:
result = st.floats()
result = st.floats(
width=8 * dtype.itemsize,
**compat_kw(
"min_value",
"max_value",
"allow_nan",
"allow_infinity",
"exclude_min",
"exclude_max",
),
)
elif dtype.kind == "c":
# If anyone wants to add a `width` argument to `complex_numbers()`, we would
# accept a pull request and add passthrough support for magnitude bounds,
# but it's a low priority otherwise.
if dtype.itemsize == 8:
float32 = st.floats(width=32)
float32 = st.floats(width=32, **compat_kw("allow_nan", "allow_infinity"))
result = st.builds(complex, float32, float32)
else:
result = st.complex_numbers()
result = st.complex_numbers(**compat_kw("allow_nan", "allow_infinity"))
elif dtype.kind in ("S", "a"):
# Numpy strings are null-terminated; only allow round-trippable values.
# `itemsize == 0` means 'fixed length determined at array creation'
result = st.binary(max_size=dtype.itemsize or None).filter(
max_size = dtype.itemsize or None
result = st.binary(**compat_kw("min_size", max_size=max_size)).filter(
lambda b: b[-1:] != b"\0"
)
elif dtype.kind == "u":
result = st.integers(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1)
kw = compat_kw(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1)
result = st.integers(**kw)
elif dtype.kind == "i":
overflow = 2 ** (8 * dtype.itemsize - 1)
result = st.integers(min_value=-overflow, max_value=overflow - 1)
result = st.integers(**compat_kw(min_value=-overflow, max_value=overflow - 1))
elif dtype.kind == "U":
# Encoded in UTF-32 (four bytes/codepoint) and null-terminated
result = st.text(max_size=(dtype.itemsize or 0) // 4 or None).filter(
max_size = (dtype.itemsize or 0) // 4 or None
result = st.text(**compat_kw("alphabet", "min_size", max_size=max_size)).filter(
lambda b: b[-1:] != "\0"
)
elif dtype.kind in ("m", "M"):
Expand Down Expand Up @@ -303,7 +351,7 @@ def arrays(
dtype: Any,
shape: Union[int, Shape, st.SearchStrategy[Shape]],
*,
elements: Optional[st.SearchStrategy[Any]] = None,
elements: Optional[Union[SearchStrategy, Mapping[str, Any]]] = None,
fill: Optional[st.SearchStrategy[Any]] = None,
unique: bool = False,
) -> st.SearchStrategy[np.ndarray]:
Expand All @@ -317,8 +365,7 @@ def arrays(
* ``elements`` is a strategy for generating values to put in the array.
If it is None a suitable value will be inferred based on the dtype,
which may give any legal value (including eg ``NaN`` for floats).
If you have more specific requirements, you should supply your own
elements strategy.
If a mapping, it will be passed as ``**kwargs`` to ``from_dtype()``
* ``fill`` is a strategy that may be used to generate a single background
value for the array. If None, a suitable default will be inferred
based on the other arguments. If set to
Expand Down Expand Up @@ -391,7 +438,7 @@ def arrays(
)
# From here on, we're only dealing with values and it's relatively simple.
dtype = np.dtype(dtype)
if elements is None:
if elements is None or isinstance(elements, Mapping):
if dtype.kind in ("m", "M") and "[" not in dtype.str:
# For datetime and timedelta dtypes, we have a tricky situation -
# because they *may or may not* specify a unit as part of the dtype.
Expand All @@ -402,7 +449,7 @@ def arrays(
.map((dtype.str + "[{}]").format)
.flatmap(lambda d: arrays(d, shape=shape, fill=fill, unique=unique))
)
elements = from_dtype(dtype)
elements = from_dtype(dtype, **(elements or {}))
check_strategy(elements, "elements")
if isinstance(shape, int):
shape = (shape,)
Expand Down
6 changes: 6 additions & 0 deletions hypothesis-python/tests/numpy/test_argument_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def e(a, **kwargs):
e(nps.from_dtype, dtype=float),
e(nps.from_dtype, dtype=numpy.int8),
e(nps.from_dtype, dtype=1),
e(nps.from_dtype, dtype=numpy.dtype("uint8"), min_value=-999),
e(nps.from_dtype, dtype=numpy.dtype("uint8"), max_value=999),
e(nps.from_dtype, dtype=numpy.dtype("int8"), min_value=-999),
e(nps.from_dtype, dtype=numpy.dtype("int8"), max_value=999),
e(nps.from_dtype, dtype=numpy.dtype("S4"), max_size=5),
e(nps.from_dtype, dtype=numpy.dtype("U4"), max_size=5),
e(nps.valid_tuple_axes, ndim=-1),
e(nps.valid_tuple_axes, ndim=2, min_size=-1),
e(nps.valid_tuple_axes, ndim=2, min_size=3, max_size=10),
Expand Down
Loading

0 comments on commit bb02569

Please sign in to comment.