From d2e013639a66ccbe3566e1eecbcf0da1935f6fbb Mon Sep 17 00:00:00 2001 From: Zac-HD Date: Thu, 24 Sep 2020 10:34:40 +1000 Subject: [PATCH 1/2] Split up test module --- .../tests/numpy/test_from_dtype.py | 182 ++++++++++++++++++ .../tests/numpy/test_gen_data.py | 161 ---------------- 2 files changed, 182 insertions(+), 161 deletions(-) create mode 100644 hypothesis-python/tests/numpy/test_from_dtype.py diff --git a/hypothesis-python/tests/numpy/test_from_dtype.py b/hypothesis-python/tests/numpy/test_from_dtype.py new file mode 100644 index 0000000000..6d3df28cfd --- /dev/null +++ b/hypothesis-python/tests/numpy/test_from_dtype.py @@ -0,0 +1,182 @@ +# This file is part of Hypothesis, which may be found at +# https://github.com/HypothesisWorks/hypothesis/ +# +# Most of this work is copyright (C) 2013-2020 David R. MacIver +# (david@drmaciver.com), but it contains contributions by others. See +# CONTRIBUTING.rst for a full list of people who may hold copyright, and +# consult the git log if you need to determine who owns an individual +# contribution. +# +# This Source Code Form is subject to the terms of the Mozilla Public License, +# v. 2.0. If a copy of the MPL was not distributed with this file, You can +# obtain one at https://mozilla.org/MPL/2.0/. +# +# END HEADER + +import numpy as np +import pytest +from tests.common.debug import find_any + +from hypothesis import assume, given, settings, strategies as st +from hypothesis.errors import InvalidArgument +from hypothesis.extra import numpy as nps +from hypothesis.strategies._internal import SearchStrategy + +STANDARD_TYPES = [ + np.dtype(t) + for t in ( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float", + "float16", + "float32", + "float64", + "complex64", + "complex128", + "datetime64", + "timedelta64", + bool, + str, + bytes, + ) +] + + +@given(nps.nested_dtypes()) +def test_strategies_for_standard_dtypes_have_reusable_values(dtype): + assert nps.from_dtype(dtype).has_reusable_values + + +@pytest.mark.parametrize("t", STANDARD_TYPES) +def test_produces_instances(t): + @given(nps.from_dtype(t)) + def test_is_t(x): + assert isinstance(x, t.type) + assert x.dtype.kind == t.kind + + test_is_t() + + +@settings(max_examples=100) +@given(nps.nested_dtypes(max_itemsize=400), st.data()) +def test_infer_strategy_from_dtype(dtype, data): + # Given a dtype + assert isinstance(dtype, np.dtype) + # We can infer a strategy + strat = nps.from_dtype(dtype) + assert isinstance(strat, SearchStrategy) + # And use it to fill an array of that dtype + data.draw(nps.arrays(dtype, 10, elements=strat)) + + +@given(st.data()) +def test_can_cast_for_scalars(data): + # Note: this only passes with castable datatypes, certain dtype + # combinations will result in an error if numpy is not able to cast them. + dt_elements = np.dtype(data.draw(st.sampled_from(["bool", "i2"]))) + dt_desired = np.dtype( + data.draw(st.sampled_from(["i2", "float32", "float64"])) + ) + result = data.draw( + nps.arrays(dtype=dt_desired, elements=nps.from_dtype(dt_elements), shape=()) + ) + assert isinstance(result, np.ndarray) + assert result.dtype == dt_desired + + +@given(st.data()) +def test_unicode_string_dtypes_generate_unicode_strings(data): + dt = data.draw(nps.unicode_string_dtypes()) + result = data.draw(nps.from_dtype(dt)) + assert isinstance(result, str) + + +@given(nps.arrays(dtype="U99", shape=(10,))) +def test_can_unicode_strings_without_decode_error(arr): + # See https://github.com/numpy/numpy/issues/15363 + pass + + +@pytest.mark.xfail(strict=False, reason="mitigation for issue above") +def test_unicode_string_dtypes_need_not_be_utf8(): + def cannot_encode(string): + try: + string.encode("utf-8") + return False + except UnicodeEncodeError: + return True + + find_any(nps.from_dtype(np.dtype("U")), cannot_encode) + + +@given(st.data()) +def test_byte_string_dtypes_generate_unicode_strings(data): + dt = data.draw(nps.byte_string_dtypes()) + result = data.draw(nps.from_dtype(dt)) + assert isinstance(result, bytes) + + +@pytest.mark.parametrize("dtype", ["U", "S", "a"]) +def test_unsized_strings_length_gt_one(dtype): + # See https://github.com/HypothesisWorks/hypothesis/issues/2229 + find_any(nps.arrays(dtype=dtype, shape=1), lambda arr: len(arr[0]) >= 2) + + +@given( + st.data(), + st.builds( + "{}[{}]".format, + st.sampled_from(("datetime64", "timedelta64")), + st.sampled_from(nps.TIME_RESOLUTIONS), + ).map(np.dtype), +) +def test_inferring_from_time_dtypes_gives_same_dtype(data, dtype): + ex = data.draw(nps.from_dtype(dtype)) + assert dtype == ex.dtype + + +@given(st.data(), nps.byte_string_dtypes() | nps.unicode_string_dtypes()) +def test_inferred_string_strategies_roundtrip(data, dtype): + # Check that we never generate too-long or nul-terminated strings, which + # cannot be read back out of an array. + arr = np.zeros(shape=1, dtype=dtype) + ex = data.draw(nps.from_dtype(arr.dtype)) + arr[0] = ex + assert arr[0] == ex + + +@given(st.data(), nps.scalar_dtypes()) +def test_all_inferred_scalar_strategies_roundtrip(data, dtype): + # We only check scalars here, because record/compound/nested dtypes always + # give an array of np.void objects. We're interested in whether scalar + # values are safe, not known type coercion. + arr = np.zeros(shape=1, dtype=dtype) + ex = data.draw(nps.from_dtype(arr.dtype)) + assume(ex == ex) # If not, the roundtrip test *should* fail! (eg NaN) + arr[0] = ex + assert arr[0] == ex + + +@pytest.mark.parametrize("dtype_str", ["m8", "M8"]) +@given(data=st.data()) +def test_from_dtype_works_without_time_unit(data, dtype_str): + arr = data.draw(nps.from_dtype(np.dtype(dtype_str))) + assert (dtype_str + "[") in arr.dtype.str + + +@pytest.mark.parametrize("dtype_str", ["m8", "M8"]) +@given(data=st.data()) +def test_arrays_selects_consistent_time_unit(data, dtype_str): + arr = data.draw(nps.arrays(dtype_str, 10)) + assert (dtype_str + "[") in arr.dtype.str + + +def test_arrays_gives_useful_error_on_inconsistent_time_unit(): + with pytest.raises(InvalidArgument, match="mismatch of time units in dtypes"): + nps.arrays("m8[Y]", 10, elements=nps.from_dtype(np.dtype("m8[D]"))).example() diff --git a/hypothesis-python/tests/numpy/test_gen_data.py b/hypothesis-python/tests/numpy/test_gen_data.py index 3edd249579..6381f6ec4f 100644 --- a/hypothesis-python/tests/numpy/test_gen_data.py +++ b/hypothesis-python/tests/numpy/test_gen_data.py @@ -23,53 +23,11 @@ from hypothesis import HealthCheck, assume, given, note, settings, strategies as st from hypothesis.errors import InvalidArgument, Unsatisfiable from hypothesis.extra import numpy as nps -from hypothesis.strategies._internal import SearchStrategy from tests.common.debug import find_any, minimal from tests.common.utils import fails_with, flaky ANY_SHAPE = nps.array_shapes(min_dims=0, max_dims=32, min_side=0, max_side=32) ANY_NONZERO_SHAPE = nps.array_shapes(min_dims=0, max_dims=32, min_side=1, max_side=32) -STANDARD_TYPES = list( - map( - np.dtype, - [ - "int8", - "int16", - "int32", - "int64", - "uint8", - "uint16", - "uint32", - "uint64", - "float", - "float16", - "float32", - "float64", - "complex64", - "complex128", - "datetime64", - "timedelta64", - bool, - str, - bytes, - ], - ) -) - - -@given(nps.nested_dtypes()) -def test_strategies_for_standard_dtypes_have_reusable_values(dtype): - assert nps.from_dtype(dtype).has_reusable_values - - -@pytest.mark.parametrize("t", STANDARD_TYPES) -def test_produces_instances(t): - @given(nps.from_dtype(t)) - def test_is_t(x): - assert isinstance(x, t.type) - assert x.dtype.kind == t.kind - - test_is_t() @given(nps.arrays(float, ())) @@ -211,18 +169,6 @@ def test_can_generate_data_compound_dtypes(arr): assert isinstance(arr, np.ndarray) -@settings(max_examples=100) -@given(nps.nested_dtypes(max_itemsize=400), st.data()) -def test_infer_strategy_from_dtype(dtype, data): - # Given a dtype - assert isinstance(dtype, np.dtype) - # We can infer a strategy - strat = nps.from_dtype(dtype) - assert isinstance(strat, SearchStrategy) - # And use it to fill an array of that dtype - data.draw(nps.arrays(dtype, 10, elements=strat)) - - @given(nps.nested_dtypes()) def test_np_dtype_is_idempotent(dtype): assert dtype == np.dtype(dtype) @@ -279,21 +225,6 @@ def test_can_draw_arrays_from_scalars(data): assert result.dtype == dt -@given(st.data()) -def test_can_cast_for_scalars(data): - # Note: this only passes with castable datatypes, certain dtype - # combinations will result in an error if numpy is not able to cast them. - dt_elements = np.dtype(data.draw(st.sampled_from(["bool", "i2"]))) - dt_desired = np.dtype( - data.draw(st.sampled_from(["i2", "float32", "float64"])) - ) - result = data.draw( - nps.arrays(dtype=dt_desired, elements=nps.from_dtype(dt_elements), shape=()) - ) - assert isinstance(result, np.ndarray) - assert result.dtype == dt_desired - - @given(st.data()) def test_can_cast_for_arrays(data): # Note: this only passes with castable datatypes, certain dtype @@ -311,44 +242,6 @@ def test_can_cast_for_arrays(data): assert result.dtype == dt_desired -@given(st.data()) -def test_unicode_string_dtypes_generate_unicode_strings(data): - dt = data.draw(nps.unicode_string_dtypes()) - result = data.draw(nps.from_dtype(dt)) - assert isinstance(result, str) - - -@given(nps.arrays(dtype="U99", shape=(10,))) -def test_can_unicode_strings_without_decode_error(arr): - # See https://github.com/numpy/numpy/issues/15363 - pass - - -@pytest.mark.xfail(strict=False, reason="mitigation for issue above") -def test_unicode_string_dtypes_need_not_be_utf8(): - def cannot_encode(string): - try: - string.encode("utf-8") - return False - except UnicodeEncodeError: - return True - - find_any(nps.from_dtype(np.dtype("U")), cannot_encode) - - -@given(st.data()) -def test_byte_string_dtypes_generate_unicode_strings(data): - dt = data.draw(nps.byte_string_dtypes()) - result = data.draw(nps.from_dtype(dt)) - assert isinstance(result, bytes) - - -@pytest.mark.parametrize("dtype", ["U", "S", "a"]) -def test_unsized_strings_length_gt_one(dtype): - # See https://github.com/HypothesisWorks/hypothesis/issues/2229 - find_any(nps.arrays(dtype=dtype, shape=1), lambda arr: len(arr[0]) >= 2) - - @given(nps.arrays(dtype="int8", shape=st.integers(0, 20), unique=True)) def test_array_values_are_unique(arr): assert len(set(arr)) == len(arr) @@ -425,41 +318,6 @@ def test_may_not_fill_with_non_nan_when_unique_is_set_and_type_is_not_number(arr pass -@given( - st.data(), - st.builds( - "{}[{}]".format, - st.sampled_from(("datetime64", "timedelta64")), - st.sampled_from(nps.TIME_RESOLUTIONS), - ).map(np.dtype), -) -def test_inferring_from_time_dtypes_gives_same_dtype(data, dtype): - ex = data.draw(nps.from_dtype(dtype)) - assert dtype == ex.dtype - - -@given(st.data(), nps.byte_string_dtypes() | nps.unicode_string_dtypes()) -def test_inferred_string_strategies_roundtrip(data, dtype): - # Check that we never generate too-long or nul-terminated strings, which - # cannot be read back out of an array. - arr = np.zeros(shape=1, dtype=dtype) - ex = data.draw(nps.from_dtype(arr.dtype)) - arr[0] = ex - assert arr[0] == ex - - -@given(st.data(), nps.scalar_dtypes()) -def test_all_inferred_scalar_strategies_roundtrip(data, dtype): - # We only check scalars here, because record/compound/nested dtypes always - # give an array of np.void objects. We're interested in whether scalar - # values are safe, not known type coercion. - arr = np.zeros(shape=1, dtype=dtype) - ex = data.draw(nps.from_dtype(arr.dtype)) - assume(ex == ex) # If not, the roundtrip test *should* fail! (eg NaN) - arr[0] = ex - assert arr[0] == ex - - @pytest.mark.parametrize("fill", [False, True]) @fails_with(InvalidArgument) @given(st.data()) @@ -1259,25 +1117,6 @@ def test_basic_indices_generate_valid_indexers( assert np.shares_memory(view, array) -@pytest.mark.parametrize("dtype_str", ["m8", "M8"]) -@given(data=st.data()) -def test_from_dtype_works_without_time_unit(data, dtype_str): - arr = data.draw(nps.from_dtype(np.dtype(dtype_str))) - assert (dtype_str + "[") in arr.dtype.str - - -@pytest.mark.parametrize("dtype_str", ["m8", "M8"]) -@given(data=st.data()) -def test_arrays_selects_consistent_time_unit(data, dtype_str): - arr = data.draw(nps.arrays(dtype_str, 10)) - assert (dtype_str + "[") in arr.dtype.str - - -def test_arrays_gives_useful_error_on_inconsistent_time_unit(): - with pytest.raises(InvalidArgument, match="mismatch of time units in dtypes"): - nps.arrays("m8[Y]", 10, elements=nps.from_dtype(np.dtype("m8[D]"))).example() - - # addresses https://github.com/HypothesisWorks/hypothesis/issues/2582 @given( nps.arrays( From 8f2c755a5308f6b94fbfa810581a8eee375664f0 Mon Sep 17 00:00:00 2001 From: Zac-HD Date: Thu, 24 Sep 2020 10:34:40 +1000 Subject: [PATCH 2/2] Accept from_dtype(**kw) --- hypothesis-python/RELEASE.rst | 11 +++ .../src/hypothesis/extra/numpy.py | 91 ++++++++++++++----- .../tests/numpy/test_argument_validation.py | 6 ++ .../tests/numpy/test_from_dtype.py | 43 ++++++++- .../tests/numpy/test_gen_data.py | 23 +++++ 5 files changed, 151 insertions(+), 23 deletions(-) create mode 100644 hypothesis-python/RELEASE.rst diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..edf221b4fd --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,11 @@ +RELEASE_TYPE: minor + +This release upgrades the :func:`~hypothesis.extra.numpy.from_dtype` strategy +to pass optional ``**kwargs`` to the inferred strategy, and upgrades the +:func:`~hypothesis.extra.numpy.arrays` strategy to accept an ``elements=kwargs`` +dict to pass through to :func:`~hypothesis.extra.numpy.from_dtype`. + +``arrays(floating_dtypes(), shape, elements={"min_value": -10, "max_value": 10})`` +is a particularly useful pattern, as it allows for any floating dtype without +triggering the roundoff warning for smaller types or sacrificing variety for +larger types (:issue:`2552`). diff --git a/hypothesis-python/src/hypothesis/extra/numpy.py b/hypothesis-python/src/hypothesis/extra/numpy.py index f5fc4c2bc1..194564210c 100644 --- a/hypothesis-python/src/hypothesis/extra/numpy.py +++ b/hypothesis-python/src/hypothesis/extra/numpy.py @@ -15,7 +15,7 @@ import math import re -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, Mapping, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np @@ -41,49 +41,97 @@ class BroadcastableShapes(NamedTuple): @st.defines_strategy(force_reusable_values=True) -def from_dtype(dtype: np.dtype) -> st.SearchStrategy[Any]: - """Creates a strategy which can generate any value of the given dtype.""" +def from_dtype( + dtype: np.dtype, + *, + alphabet: Optional[st.SearchStrategy[str]] = None, + min_size: int = 0, + max_size: Optional[int] = None, + min_value: Union[int, float, None] = None, + max_value: Union[int, float, None] = None, + allow_nan: Optional[bool] = None, + allow_infinity: Optional[bool] = None, + exclude_min: Optional[bool] = None, + exclude_max: Optional[bool] = None, +) -> st.SearchStrategy[Any]: + """Creates a strategy which can generate any value of the given dtype. + + Compatible ``**kwargs`` are passed to the inferred strategy function for + integers, floats, and strings. This allows you to customise the min and max + values, control the length or contents of strings, or exclude non-finite + numbers. This is particularly useful when kwargs are passed through from + :func:`arrays` which allow a variety of numeric dtypes, as it seamlessly + handles the ``width`` or representable bounds for you. See :issue:`2552` + for more detail. + """ check_type(np.dtype, dtype, "dtype") + kwargs = {k: v for k, v in locals().items() if k != "dtype" and v is not None} + # Compound datatypes, eg 'f4,f4,f4' if dtype.names is not None: # mapping np.void.type over a strategy is nonsense, so return now. - return st.tuples(*[from_dtype(dtype.fields[name][0]) for name in dtype.names]) + subs = [from_dtype(dtype.fields[name][0], **kwargs) for name in dtype.names] + return st.tuples(*subs) # Subarray datatypes, eg '(2, 3)i4' if dtype.subdtype is not None: subtype, shape = dtype.subdtype - return arrays(subtype, shape) + return arrays(subtype, shape, elements=kwargs) + + def compat_kw(*args, **kw): + """Update default args to the strategy with user-supplied keyword args.""" + assert {"min_value", "max_value", "max_size"}.issuperset(kw) + for key in set(kwargs).intersection(kw): + msg = f"dtype {dtype!r} requires {key}={kwargs[key]!r} to be %s {kw[key]!r}" + if kw[key] is not None: + if key.startswith("min_") and kw[key] > kwargs[key]: + raise InvalidArgument(msg % ("at least",)) + elif key.startswith("max_") and kw[key] < kwargs[key]: + raise InvalidArgument(msg % ("at most",)) + kw.update({k: v for k, v in kwargs.items() if k in args or k in kw}) + return kw # Scalar datatypes if dtype.kind == "b": result = st.booleans() # type: SearchStrategy[Any] elif dtype.kind == "f": - if dtype.itemsize == 2: - result = st.floats(width=16) - elif dtype.itemsize == 4: - result = st.floats(width=32) - else: - result = st.floats() + result = st.floats( + width=8 * dtype.itemsize, + **compat_kw( + "min_value", + "max_value", + "allow_nan", + "allow_infinity", + "exclude_min", + "exclude_max", + ), + ) elif dtype.kind == "c": + # If anyone wants to add a `width` argument to `complex_numbers()`, we would + # accept a pull request and add passthrough support for magnitude bounds, + # but it's a low priority otherwise. if dtype.itemsize == 8: - float32 = st.floats(width=32) + float32 = st.floats(width=32, **compat_kw("allow_nan", "allow_infinity")) result = st.builds(complex, float32, float32) else: - result = st.complex_numbers() + result = st.complex_numbers(**compat_kw("allow_nan", "allow_infinity")) elif dtype.kind in ("S", "a"): # Numpy strings are null-terminated; only allow round-trippable values. # `itemsize == 0` means 'fixed length determined at array creation' - result = st.binary(max_size=dtype.itemsize or None).filter( + max_size = dtype.itemsize or None + result = st.binary(**compat_kw("min_size", max_size=max_size)).filter( lambda b: b[-1:] != b"\0" ) elif dtype.kind == "u": - result = st.integers(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1) + kw = compat_kw(min_value=0, max_value=2 ** (8 * dtype.itemsize) - 1) + result = st.integers(**kw) elif dtype.kind == "i": overflow = 2 ** (8 * dtype.itemsize - 1) - result = st.integers(min_value=-overflow, max_value=overflow - 1) + result = st.integers(**compat_kw(min_value=-overflow, max_value=overflow - 1)) elif dtype.kind == "U": # Encoded in UTF-32 (four bytes/codepoint) and null-terminated - result = st.text(max_size=(dtype.itemsize or 0) // 4 or None).filter( + max_size = (dtype.itemsize or 0) // 4 or None + result = st.text(**compat_kw("alphabet", "min_size", max_size=max_size)).filter( lambda b: b[-1:] != "\0" ) elif dtype.kind in ("m", "M"): @@ -303,7 +351,7 @@ def arrays( dtype: Any, shape: Union[int, Shape, st.SearchStrategy[Shape]], *, - elements: Optional[st.SearchStrategy[Any]] = None, + elements: Optional[Union[SearchStrategy, Mapping[str, Any]]] = None, fill: Optional[st.SearchStrategy[Any]] = None, unique: bool = False, ) -> st.SearchStrategy[np.ndarray]: @@ -317,8 +365,7 @@ def arrays( * ``elements`` is a strategy for generating values to put in the array. If it is None a suitable value will be inferred based on the dtype, which may give any legal value (including eg ``NaN`` for floats). - If you have more specific requirements, you should supply your own - elements strategy. + If a mapping, it will be passed as ``**kwargs`` to ``from_dtype()`` * ``fill`` is a strategy that may be used to generate a single background value for the array. If None, a suitable default will be inferred based on the other arguments. If set to @@ -391,7 +438,7 @@ def arrays( ) # From here on, we're only dealing with values and it's relatively simple. dtype = np.dtype(dtype) - if elements is None: + if elements is None or isinstance(elements, Mapping): if dtype.kind in ("m", "M") and "[" not in dtype.str: # For datetime and timedelta dtypes, we have a tricky situation - # because they *may or may not* specify a unit as part of the dtype. @@ -402,7 +449,7 @@ def arrays( .map((dtype.str + "[{}]").format) .flatmap(lambda d: arrays(d, shape=shape, fill=fill, unique=unique)) ) - elements = from_dtype(dtype) + elements = from_dtype(dtype, **(elements or {})) check_strategy(elements, "elements") if isinstance(shape, int): shape = (shape,) diff --git a/hypothesis-python/tests/numpy/test_argument_validation.py b/hypothesis-python/tests/numpy/test_argument_validation.py index dd1eadcfe2..2440a13e5d 100644 --- a/hypothesis-python/tests/numpy/test_argument_validation.py +++ b/hypothesis-python/tests/numpy/test_argument_validation.py @@ -65,6 +65,12 @@ def e(a, **kwargs): e(nps.from_dtype, dtype=float), e(nps.from_dtype, dtype=numpy.int8), e(nps.from_dtype, dtype=1), + e(nps.from_dtype, dtype=numpy.dtype("uint8"), min_value=-999), + e(nps.from_dtype, dtype=numpy.dtype("uint8"), max_value=999), + e(nps.from_dtype, dtype=numpy.dtype("int8"), min_value=-999), + e(nps.from_dtype, dtype=numpy.dtype("int8"), max_value=999), + e(nps.from_dtype, dtype=numpy.dtype("S4"), max_size=5), + e(nps.from_dtype, dtype=numpy.dtype("U4"), max_size=5), e(nps.valid_tuple_axes, ndim=-1), e(nps.valid_tuple_axes, ndim=2, min_size=-1), e(nps.valid_tuple_axes, ndim=2, min_size=3, max_size=10), diff --git a/hypothesis-python/tests/numpy/test_from_dtype.py b/hypothesis-python/tests/numpy/test_from_dtype.py index 6d3df28cfd..667313738d 100644 --- a/hypothesis-python/tests/numpy/test_from_dtype.py +++ b/hypothesis-python/tests/numpy/test_from_dtype.py @@ -15,12 +15,12 @@ import numpy as np import pytest -from tests.common.debug import find_any from hypothesis import assume, given, settings, strategies as st from hypothesis.errors import InvalidArgument from hypothesis.extra import numpy as nps from hypothesis.strategies._internal import SearchStrategy +from tests.common.debug import find_any STANDARD_TYPES = [ np.dtype(t) @@ -180,3 +180,44 @@ def test_arrays_selects_consistent_time_unit(data, dtype_str): def test_arrays_gives_useful_error_on_inconsistent_time_unit(): with pytest.raises(InvalidArgument, match="mismatch of time units in dtypes"): nps.arrays("m8[Y]", 10, elements=nps.from_dtype(np.dtype("m8[D]"))).example() + + +@pytest.mark.parametrize( + "dtype, kwargs, pred", + [ + # Floating point: bounds, exclusive bounds, and excluding nonfinites + (float, {"min_value": 1, "max_value": 2}, lambda x: 1 <= x <= 2), + ( + float, + {"min_value": 1, "max_value": 2, "exclude_min": True, "exclude_max": True}, + lambda x: 1 < x < 2, + ), + (float, {"allow_nan": False}, lambda x: not np.isnan(x)), + (float, {"allow_infinity": False}, lambda x: not np.isinf(x)), + (float, {"allow_nan": False, "allow_infinity": False}, np.isfinite), + (complex, {"allow_nan": False}, lambda x: not np.isnan(x)), + (complex, {"allow_infinity": False}, lambda x: not np.isinf(x)), + (complex, {"allow_nan": False, "allow_infinity": False}, np.isfinite), + # Integer bounds, limited to the representable range + ("int8", {"min_value": -1, "max_value": 1}, lambda x: -1 <= x <= 1), + ("uint8", {"min_value": 1, "max_value": 2}, lambda x: 1 <= x <= 2), + # String arguments, bounding size and unicode alphabet + ("S", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("S4", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("U", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("U4", {"min_size": 1, "max_size": 2}, lambda x: 1 <= len(x) <= 2), + ("U", {"alphabet": "abc"}, lambda x: set(x).issubset("abc")), + ], +) +@given(data=st.data()) +def test_from_dtype_with_kwargs(data, dtype, kwargs, pred): + value = data.draw(nps.from_dtype(np.dtype(dtype), **kwargs)) + assert pred(value) + + +@given(nps.from_dtype(np.dtype("U20,uint8,float32"), min_size=1, allow_nan=False)) +def test_customize_structured_dtypes(x): + name, age, score = x + assert len(name) >= 1 + assert 0 <= age <= 255 + assert not np.isnan(score) diff --git a/hypothesis-python/tests/numpy/test_gen_data.py b/hypothesis-python/tests/numpy/test_gen_data.py index 6381f6ec4f..eda7797081 100644 --- a/hypothesis-python/tests/numpy/test_gen_data.py +++ b/hypothesis-python/tests/numpy/test_gen_data.py @@ -367,6 +367,29 @@ def test_inferred_floats_do_not_overflow(arr): pass +@given(nps.arrays(dtype="float16", shape=10, elements={"min_value": 0, "max_value": 1})) +def test_inferred_floats_can_be_constrained_at_low_width(arr): + assert (arr >= 0).all() + assert (arr <= 1).all() + + +@given( + nps.arrays( + dtype="float16", + shape=10, + elements={ + "min_value": 0, + "max_value": 1, + "exclude_min": True, + "exclude_max": True, + }, + ) +) +def test_inferred_floats_can_be_constrained_at_low_width_excluding_endpoints(arr): + assert (arr > 0).all() + assert (arr < 1).all() + + @given( nps.arrays( dtype="float16",