From 942b062f95196d7e32497deb54bcaec211de7f8f Mon Sep 17 00:00:00 2001 From: Zac-HD Date: Wed, 16 Sep 2020 13:38:51 +1000 Subject: [PATCH] Accept from_dtype(**kw) --- hypothesis-python/RELEASE.rst | 11 +++ .../src/hypothesis/extra/numpy.py | 86 ++++++++++++++----- .../tests/numpy/test_argument_validation.py | 6 ++ .../tests/numpy/test_from_dtype.py | 40 ++++++++- 4 files changed, 122 insertions(+), 21 deletions(-) create mode 100644 hypothesis-python/RELEASE.rst diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 00000000000..edf221b4fd4 --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,11 @@ +RELEASE_TYPE: minor + +This release upgrades the :func:`~hypothesis.extra.numpy.from_dtype` strategy +to pass optional ``**kwargs`` to the inferred strategy, and upgrades the +:func:`~hypothesis.extra.numpy.arrays` strategy to accept an ``elements=kwargs`` +dict to pass through to :func:`~hypothesis.extra.numpy.from_dtype`. + +``arrays(floating_dtypes(), shape, elements={"min_value": -10, "max_value": 10})`` +is a particularly useful pattern, as it allows for any floating dtype without +triggering the roundoff warning for smaller types or sacrificing variety for +larger types (:issue:`2552`). diff --git a/hypothesis-python/src/hypothesis/extra/numpy.py b/hypothesis-python/src/hypothesis/extra/numpy.py index b300e076916..49bb05e19ea 100644 --- a/hypothesis-python/src/hypothesis/extra/numpy.py +++ b/hypothesis-python/src/hypothesis/extra/numpy.py @@ -15,7 +15,7 @@ import math import re -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, Mapping, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np @@ -41,30 +41,74 @@ @st.defines_strategy(force_reusable_values=True) -def from_dtype(dtype: np.dtype) -> st.SearchStrategy[Any]: - """Creates a strategy which can generate any value of the given dtype.""" +def from_dtype( + dtype: np.dtype, + *, + alphabet: Optional[st.SearchStrategy[str]] = None, + min_size: int = 0, + max_size: Optional[int] = None, + min_value: Union[int, float, None] = None, + max_value: Union[int, float, None] = None, + allow_nan: Optional[bool] = None, + allow_infinity: Optional[bool] = None, + exclude_min: Optional[bool] = None, + exclude_max: Optional[bool] = None, +) -> st.SearchStrategy[Any]: + """Creates a strategy which can generate any value of the given dtype. + + Compatible ``**kwargs`` are passed to the inferred strategy function for + integers, floats, and strings. This allows you to customise the min and max + values, control the length or contents of strings, or exclude non-finite + numbers. This is particularly useful when kwargs are passed through from + :func:`arrays` which allow a variety of numeric dtypes, as it seamlessly + handles the ``width`` or representable bounds for you. See :issue:`2552` + for more detail. + """ check_type(np.dtype, dtype, "dtype") + kwargs = {k: v for k, v in locals().items() if k != "dtype" and v is not None} + # Compound datatypes, eg 'f4,f4,f4' if dtype.names is not None: # mapping np.void.type over a strategy is nonsense, so return now. - return st.tuples(*[from_dtype(dtype.fields[name][0]) for name in dtype.names]) + subs = [from_dtype(dtype.fields[name][0], **kwargs) for name in dtype.names] + return st.tuples(*subs) # Subarray datatypes, eg '(2, 3)i4' if dtype.subdtype is not None: subtype, shape = dtype.subdtype - return arrays(subtype, shape) + return arrays(subtype, shape, elements=kwargs) + + def compat_kw(*args, **kw): + """Update default args to the strategy with user-supplied keyword args.""" + assert {"min_value", "max_value", "max_size"}.issuperset(kw) + for key in set(kwargs).intersection(kw): + msg = f"dtype {dtype!r} requires {key}={kwargs[key]!r} to be %s {kw[key]!r}" + if kw[key] is not None: + if key.startswith("min_") and kw[key] > kwargs[key]: + raise InvalidArgument(msg % ("at least",)) + elif key.startswith("max_") and kw[key] < kwargs[key]: + raise InvalidArgument(msg % ("at most",)) + kw.update({k: v for k, v in kwargs.items() if k in args or k in kw}) + return kw # Scalar datatypes if dtype.kind == "b": result = st.booleans() # type: SearchStrategy[Any] elif dtype.kind == "f": - if dtype.itemsize == 2: - result = st.floats(width=16) - elif dtype.itemsize == 4: - result = st.floats(width=32) - else: - result = st.floats() + result = st.floats( + width=8 * dtype.itemsize, + **compat_kw( + "min_value", + "max_value", + "allow_nan", + "allow_infinity", + "exclude_min", + "exclude_max", + ), + ) elif dtype.kind == "c": + # If anyone wants to add a `width` argument to `complex_numbers()`, + # we would accept a pull request but it's a low priority otherwise. if dtype.itemsize == 8: float32 = st.floats(width=32) result = st.builds(complex, float32, float32) @@ -73,17 +117,20 @@ def from_dtype(dtype: np.dtype) -> st.SearchStrategy[Any]: elif dtype.kind in ("S", "a"): # Numpy strings are null-terminated; only allow round-trippable values. # `itemsize == 0` means 'fixed length determined at array creation' - result = st.binary(max_size=dtype.itemsize or None).filter( + max_size = dtype.itemsize or None + result = st.binary(**compat_kw("min_size", max_size=max_size)).filter( lambda b: b[-1:] != b"\0" ) elif dtype.kind == "u": - result = st.integers(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1) + kw = compat_kw(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1) + result = st.integers(**kw) elif dtype.kind == "i": overflow = 2 ** (8 * dtype.itemsize - 1) - result = st.integers(min_value=-overflow, max_value=overflow - 1) + result = st.integers(**compat_kw(min_value=-overflow, max_value=overflow - 1)) elif dtype.kind == "U": # Encoded in UTF-32 (four bytes/codepoint) and null-terminated - result = st.text(max_size=(dtype.itemsize or 0) // 4 or None).filter( + max_size = (dtype.itemsize or 0) // 4 or None + result = st.text(**compat_kw("alphabet", "min_size", max_size=max_size)).filter( lambda b: b[-1:] != "\0" ) elif dtype.kind in ("m", "M"): @@ -303,7 +350,7 @@ def arrays( dtype: Any, shape: Union[int, Shape, st.SearchStrategy[Shape]], *, - elements: Optional[st.SearchStrategy[Any]] = None, + elements: Optional[Union[SearchStrategy, Mapping[str, Any]]] = None, fill: Optional[st.SearchStrategy[Any]] = None, unique: bool = False ) -> st.SearchStrategy[np.ndarray]: @@ -317,8 +364,7 @@ def arrays( * ``elements`` is a strategy for generating values to put in the array. If it is None a suitable value will be inferred based on the dtype, which may give any legal value (including eg ``NaN`` for floats). - If you have more specific requirements, you should supply your own - elements strategy. + If a mapping, it will be passed as ``**kwargs`` to ``from_dtype()`` * ``fill`` is a strategy that may be used to generate a single background value for the array. If None, a suitable default will be inferred based on the other arguments. If set to @@ -391,7 +437,7 @@ def arrays( ) # From here on, we're only dealing with values and it's relatively simple. dtype = np.dtype(dtype) - if elements is None: + if elements is None or isinstance(elements, Mapping): if dtype.kind in ("m", "M") and "[" not in dtype.str: # For datetime and timedelta dtypes, we have a tricky situation - # because they *may or may not* specify a unit as part of the dtype. @@ -402,7 +448,7 @@ def arrays( .map((dtype.str + "[{}]").format) .flatmap(lambda d: arrays(d, shape=shape, fill=fill, unique=unique)) ) - elements = from_dtype(dtype) + elements = from_dtype(dtype, **(elements or {})) check_strategy(elements, "elements") if isinstance(shape, int): shape = (shape,) diff --git a/hypothesis-python/tests/numpy/test_argument_validation.py b/hypothesis-python/tests/numpy/test_argument_validation.py index 939cf65b46d..680a7672da8 100644 --- a/hypothesis-python/tests/numpy/test_argument_validation.py +++ b/hypothesis-python/tests/numpy/test_argument_validation.py @@ -65,6 +65,12 @@ def e(a, **kwargs): e(nps.from_dtype, dtype=float), e(nps.from_dtype, dtype=numpy.int8), e(nps.from_dtype, dtype=1), + e(nps.from_dtype, dtype=numpy.dtype("uint8"), min_value=-999), + e(nps.from_dtype, dtype=numpy.dtype("uint8"), max_value=999), + e(nps.from_dtype, dtype=numpy.dtype("int8"), min_value=-999), + e(nps.from_dtype, dtype=numpy.dtype("int8"), max_value=999), + e(nps.from_dtype, dtype=numpy.dtype("S4"), max_size=5), + e(nps.from_dtype, dtype=numpy.dtype("U4"), max_size=5), e(nps.valid_tuple_axes, ndim=-1), e(nps.valid_tuple_axes, ndim=2, min_size=-1), e(nps.valid_tuple_axes, ndim=2, min_size=3, max_size=10), diff --git a/hypothesis-python/tests/numpy/test_from_dtype.py b/hypothesis-python/tests/numpy/test_from_dtype.py index 6d3df28cfdd..b4957a66a99 100644 --- a/hypothesis-python/tests/numpy/test_from_dtype.py +++ b/hypothesis-python/tests/numpy/test_from_dtype.py @@ -15,12 +15,12 @@ import numpy as np import pytest -from tests.common.debug import find_any from hypothesis import assume, given, settings, strategies as st from hypothesis.errors import InvalidArgument from hypothesis.extra import numpy as nps from hypothesis.strategies._internal import SearchStrategy +from tests.common.debug import find_any STANDARD_TYPES = [ np.dtype(t) @@ -180,3 +180,41 @@ def test_arrays_selects_consistent_time_unit(data, dtype_str): def test_arrays_gives_useful_error_on_inconsistent_time_unit(): with pytest.raises(InvalidArgument, match="mismatch of time units in dtypes"): nps.arrays("m8[Y]", 10, elements=nps.from_dtype(np.dtype("m8[D]"))).example() + + +@pytest.mark.parametrize( + "dtype, kwargs, pred", + [ + # Floating point: bounds, exclusive bounds, and excluding nonfinites + (float, {"min_value": 1, "max_value": 2}, lambda x: 1 <= x <= 2), + ( + float, + {"min_value": 1, "max_value": 2, "exclude_min": True, "exclude_max": True}, + lambda x: 1 < x < 2, + ), + (float, {"allow_nan": False}, lambda x: not np.isnan(x)), + (float, {"allow_infinity": False}, lambda x: not np.isinf(x)), + (float, {"allow_nan": False, "allow_infinity": False}, np.isfinite), + # Integer bounds, limited to the representable range + ("int8", {"min_value": -1, "max_value": 1}, lambda x: -1 <= x <= 1), + ("uint8", {"min_value": 1, "max_value": 2}, lambda x: 1 <= x <= 2), + # String arguments, bounding size and unicode alphabet + ("S", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("S4", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("U", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("U4", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("U", {"alphabet": "abc"}, lambda x: set(x).issubset("abc")), + ], +) +@given(data=st.data()) +def test_from_dtype_with_kwargs(data, dtype, kwargs, pred): + value = data.draw(nps.from_dtype(np.dtype(dtype), **kwargs)) + assert pred(value) + + +@given(nps.from_dtype(np.dtype("U20,uint8,float32"), min_size=1, allow_nan=False)) +def test_customize_structured_dtypes(x): + name, age, score = x + assert len(name) >= 1 + assert 0 <= age <= 255 + assert not np.isnan(score)