Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Customise nps.from_dtype() with keyword argument passthrough #2619

Merged
merged 2 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
RELEASE_TYPE: minor

This release upgrades the :func:`~hypothesis.extra.numpy.from_dtype` strategy
to pass optional ``**kwargs`` to the inferred strategy, and upgrades the
:func:`~hypothesis.extra.numpy.arrays` strategy to accept an ``elements=kwargs``
dict to pass through to :func:`~hypothesis.extra.numpy.from_dtype`.

``arrays(floating_dtypes(), shape, elements={"min_value": -10, "max_value": 10})``
is a particularly useful pattern, as it allows for any floating dtype without
triggering the roundoff warning for smaller types or sacrificing variety for
larger types (:issue:`2552`).
91 changes: 69 additions & 22 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import math
import re
from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union
from typing import Any, Mapping, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np

Expand All @@ -41,49 +41,97 @@ class BroadcastableShapes(NamedTuple):


@st.defines_strategy(force_reusable_values=True)
def from_dtype(dtype: np.dtype) -> st.SearchStrategy[Any]:
"""Creates a strategy which can generate any value of the given dtype."""
def from_dtype(
dtype: np.dtype,
*,
alphabet: Optional[st.SearchStrategy[str]] = None,
min_size: int = 0,
max_size: Optional[int] = None,
min_value: Union[int, float, None] = None,
max_value: Union[int, float, None] = None,
allow_nan: Optional[bool] = None,
allow_infinity: Optional[bool] = None,
exclude_min: Optional[bool] = None,
exclude_max: Optional[bool] = None,
) -> st.SearchStrategy[Any]:
"""Creates a strategy which can generate any value of the given dtype.

Compatible ``**kwargs`` are passed to the inferred strategy function for
integers, floats, and strings. This allows you to customise the min and max
values, control the length or contents of strings, or exclude non-finite
numbers. This is particularly useful when kwargs are passed through from
:func:`arrays` which allow a variety of numeric dtypes, as it seamlessly
handles the ``width`` or representable bounds for you. See :issue:`2552`
for more detail.
"""
check_type(np.dtype, dtype, "dtype")
kwargs = {k: v for k, v in locals().items() if k != "dtype" and v is not None}

# Compound datatypes, eg 'f4,f4,f4'
if dtype.names is not None:
# mapping np.void.type over a strategy is nonsense, so return now.
return st.tuples(*[from_dtype(dtype.fields[name][0]) for name in dtype.names])
subs = [from_dtype(dtype.fields[name][0], **kwargs) for name in dtype.names]
return st.tuples(*subs)

# Subarray datatypes, eg '(2, 3)i4'
if dtype.subdtype is not None:
subtype, shape = dtype.subdtype
return arrays(subtype, shape)
return arrays(subtype, shape, elements=kwargs)

def compat_kw(*args, **kw):
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
"""Update default args to the strategy with user-supplied keyword args."""
assert {"min_value", "max_value", "max_size"}.issuperset(kw)
for key in set(kwargs).intersection(kw):
msg = f"dtype {dtype!r} requires {key}={kwargs[key]!r} to be %s {kw[key]!r}"
if kw[key] is not None:
if key.startswith("min_") and kw[key] > kwargs[key]:
raise InvalidArgument(msg % ("at least",))
elif key.startswith("max_") and kw[key] < kwargs[key]:
raise InvalidArgument(msg % ("at most",))
kw.update({k: v for k, v in kwargs.items() if k in args or k in kw})
return kw

# Scalar datatypes
if dtype.kind == "b":
result = st.booleans() # type: SearchStrategy[Any]
elif dtype.kind == "f":
if dtype.itemsize == 2:
result = st.floats(width=16)
elif dtype.itemsize == 4:
result = st.floats(width=32)
else:
result = st.floats()
result = st.floats(
width=8 * dtype.itemsize,
**compat_kw(
"min_value",
"max_value",
"allow_nan",
"allow_infinity",
"exclude_min",
"exclude_max",
),
)
elif dtype.kind == "c":
# If anyone wants to add a `width` argument to `complex_numbers()`, we would
# accept a pull request and add passthrough support for magnitude bounds,
# but it's a low priority otherwise.
if dtype.itemsize == 8:
float32 = st.floats(width=32)
float32 = st.floats(width=32, **compat_kw("allow_nan", "allow_infinity"))
result = st.builds(complex, float32, float32)
else:
result = st.complex_numbers()
result = st.complex_numbers(**compat_kw("allow_nan", "allow_infinity"))
elif dtype.kind in ("S", "a"):
# Numpy strings are null-terminated; only allow round-trippable values.
# `itemsize == 0` means 'fixed length determined at array creation'
result = st.binary(max_size=dtype.itemsize or None).filter(
max_size = dtype.itemsize or None
result = st.binary(**compat_kw("min_size", max_size=max_size)).filter(
lambda b: b[-1:] != b"\0"
)
elif dtype.kind == "u":
result = st.integers(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1)
kw = compat_kw(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1)
result = st.integers(**kw)
elif dtype.kind == "i":
overflow = 2 ** (8 * dtype.itemsize - 1)
result = st.integers(min_value=-overflow, max_value=overflow - 1)
result = st.integers(**compat_kw(min_value=-overflow, max_value=overflow - 1))
elif dtype.kind == "U":
# Encoded in UTF-32 (four bytes/codepoint) and null-terminated
result = st.text(max_size=(dtype.itemsize or 0) // 4 or None).filter(
max_size = (dtype.itemsize or 0) // 4 or None
result = st.text(**compat_kw("alphabet", "min_size", max_size=max_size)).filter(
lambda b: b[-1:] != "\0"
)
elif dtype.kind in ("m", "M"):
Expand Down Expand Up @@ -303,7 +351,7 @@ def arrays(
dtype: Any,
shape: Union[int, Shape, st.SearchStrategy[Shape]],
*,
elements: Optional[st.SearchStrategy[Any]] = None,
elements: Optional[Union[SearchStrategy, Mapping[str, Any]]] = None,
fill: Optional[st.SearchStrategy[Any]] = None,
unique: bool = False,
) -> st.SearchStrategy[np.ndarray]:
Expand All @@ -317,8 +365,7 @@ def arrays(
* ``elements`` is a strategy for generating values to put in the array.
If it is None a suitable value will be inferred based on the dtype,
which may give any legal value (including eg ``NaN`` for floats).
If you have more specific requirements, you should supply your own
elements strategy.
If a mapping, it will be passed as ``**kwargs`` to ``from_dtype()``
* ``fill`` is a strategy that may be used to generate a single background
value for the array. If None, a suitable default will be inferred
based on the other arguments. If set to
Expand Down Expand Up @@ -391,7 +438,7 @@ def arrays(
)
# From here on, we're only dealing with values and it's relatively simple.
dtype = np.dtype(dtype)
if elements is None:
if elements is None or isinstance(elements, Mapping):
if dtype.kind in ("m", "M") and "[" not in dtype.str:
# For datetime and timedelta dtypes, we have a tricky situation -
# because they *may or may not* specify a unit as part of the dtype.
Expand All @@ -402,7 +449,7 @@ def arrays(
.map((dtype.str + "[{}]").format)
.flatmap(lambda d: arrays(d, shape=shape, fill=fill, unique=unique))
)
elements = from_dtype(dtype)
elements = from_dtype(dtype, **(elements or {}))
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
check_strategy(elements, "elements")
if isinstance(shape, int):
shape = (shape,)
Expand Down
6 changes: 6 additions & 0 deletions hypothesis-python/tests/numpy/test_argument_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ def e(a, **kwargs):
e(nps.from_dtype, dtype=float),
e(nps.from_dtype, dtype=numpy.int8),
e(nps.from_dtype, dtype=1),
e(nps.from_dtype, dtype=numpy.dtype("uint8"), min_value=-999),
e(nps.from_dtype, dtype=numpy.dtype("uint8"), max_value=999),
e(nps.from_dtype, dtype=numpy.dtype("int8"), min_value=-999),
e(nps.from_dtype, dtype=numpy.dtype("int8"), max_value=999),
e(nps.from_dtype, dtype=numpy.dtype("S4"), max_size=5),
e(nps.from_dtype, dtype=numpy.dtype("U4"), max_size=5),
e(nps.valid_tuple_axes, ndim=-1),
e(nps.valid_tuple_axes, ndim=2, min_size=-1),
e(nps.valid_tuple_axes, ndim=2, min_size=3, max_size=10),
Expand Down
Loading