Skip to content

Commit

Permalink
Merge pull request #3670 from jobh/numpy-parameterized-arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD authored Jun 9, 2023
2 parents fc566bf + 50d48a2 commit 54bf5dc
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 30 deletions.
6 changes: 6 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
RELEASE_TYPE: minor

:func:`~hypothesis.strategies.from_type` now handles numpy array types:
:obj:`np.typing.ArrayLike <numpy.typing.ArrayLike>`,
:obj:`np.typing.NDArray <numpy.typing.NDArray>`, and parameterized
versions including :class:`np.ndarray[shape, elem_type] <numpy.ndarray>`.
155 changes: 154 additions & 1 deletion hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import importlib
import math
from typing import (
TYPE_CHECKING,
Expand All @@ -16,6 +17,7 @@
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)
Expand Down Expand Up @@ -44,11 +46,27 @@
from hypothesis.internal.reflection import proxies
from hypothesis.internal.validation import check_type
from hypothesis.strategies._internal.numbers import Real
from hypothesis.strategies._internal.strategies import T, check_strategy
from hypothesis.strategies._internal.strategies import Ex, T, check_strategy
from hypothesis.strategies._internal.utils import defines_strategy


def _try_import(mod_name: str, attr_name: str) -> Any:
assert "." not in attr_name
try:
mod = importlib.import_module(mod_name)
return getattr(mod, attr_name, None)
except ImportError:
return None


if TYPE_CHECKING:
    # Static type checkers get the real numpy names directly.
    from numpy.typing import DTypeLike, NDArray
else:
    # At runtime the lookup is defensive: _try_import gives None when the
    # module cannot be imported or does not have the attribute.
    NDArray = _try_import("numpy.typing", "NDArray")

# Optional numpy typing helpers, looked up the same defensive way; each one is
# None if it is unavailable. The last two come from numpy's private ``_typing``
# package.
ArrayLike = _try_import("numpy.typing", "ArrayLike")
_NestedSequence = _try_import("numpy._typing._nested_sequence", "_NestedSequence")
_SupportsArray = _try_import("numpy._typing._array_like", "_SupportsArray")

__all__ = [
"BroadcastableShapes",
Expand Down Expand Up @@ -978,3 +996,138 @@ def array_for(index_shape, size):
return result_shape.flatmap(
lambda index_shape: st.tuples(*(array_for(index_shape, size) for size in shape))
)


def _unpack_generic(thing):
# get_origin and get_args fail on python<3.9 because (some of) the
# relevant types do not inherit from _GenericAlias. So just pick the
# value out directly.
real_thing = getattr(thing, "__origin__", None)
if real_thing is not None:
return (real_thing, getattr(thing, "__args__", ()))
else:
return (thing, ())


def _unpack_dtype(dtype):
dtype_args = getattr(dtype, "__args__", ())
if dtype_args:
assert len(dtype_args) == 1
if isinstance(dtype_args[0], TypeVar):
# numpy.dtype[+ScalarType]
assert dtype_args[0].__bound__ == np.generic
dtype = Any
else:
# plain dtype
dtype = dtype_args[0]
return dtype


def _dtype_and_shape_from_args(args):
    """Turn the type arguments of an array type into a ``(dtype, shape)`` pair.

    ``args`` holds zero, one or two type arguments. The first element of the
    result is ``scalar_dtypes()`` when the element type is ``Any`` and
    ``np.dtype(...)`` of the element type otherwise. The second element is
    ``array_shapes(max_dims=2)`` when the shape is ``Any``.
    """
    if len(args) > 1:
        # Two args: ndarray[shape, type], NDArray[*]
        assert len(args) == 2
        shape = args[0]
        assert shape is Any
        dtype = _unpack_dtype(args[1])
    else:
        # Zero args: ndarray, _SupportsArray
        # One arg: ndarray[type], _SupportsArray[type]
        shape = Any
        dtype = _unpack_dtype(args[0]) if args else Any

    resolved_dtype = scalar_dtypes() if dtype is Any else np.dtype(dtype)
    resolved_shape = array_shapes(max_dims=2) if shape is Any else shape
    return (resolved_dtype, resolved_shape)


def _from_type(thing: Type[Ex]) -> Optional[st.SearchStrategy[Ex]]:
    """Called by st.from_type to try to infer a strategy for thing using numpy.

    If we can infer a numpy-specific strategy for thing, we return that; otherwise,
    we return None so that the caller can continue with its normal resolution.
    """
    # ArrayLike, _NestedSequence and _SupportsArray are None when they could
    # not be imported, so a literal None would compare equal to them below and
    # be resolved as if it were an array type. None is never a numpy type, so
    # bail out before any of the equality checks.
    if thing is None:
        return None

    def scalars(*, ascii_bytes_only: bool = False) -> st.SearchStrategy:
        # Built on demand: most calls leave through a branch that needs
        # neither variant of this strategy.
        binary = st.binary()
        if ascii_bytes_only:
            # don't mix strings and non-ascii bytestrings (ex: ['', b'\x80']).
            # See https://github.com/numpy/numpy/issues/23899.
            binary = binary.filter(bytes.isascii)
        return st.one_of(
            [
                st.booleans(),
                st.integers(),
                st.floats(),
                st.complex_numbers(),
                st.text(),
                binary,
            ]
        )

    if thing == np.dtype:
        # Note: Parameterized dtypes and DTypeLike are not supported.
        return st.one_of(
            scalar_dtypes(),
            byte_string_dtypes(),
            unicode_string_dtypes(),
            array_dtypes(),
            nested_dtypes(),
        )

    if thing == ArrayLike:
        # We override the default type resolution to ensure the "coercible to
        # array" contract is honoured. See
        # https://github.com/HypothesisWorks/hypothesis/pull/3670#issuecomment-1578140422.
        # The actual type is (as of np 1.24), with
        # scalars:=[bool, int, float, complex, str, bytes]:
        #     Union[
        #         _SupportsArray,
        #         _NestedSequence[_SupportsArray],
        #         *scalars,
        #         _NestedSequence[Union[*scalars]]
        #     ]
        return st.one_of(
            # *scalars
            scalars(),
            # The two recursive strategies below cover the following cases:
            # - _SupportsArray (using plain ndarrays)
            # - _NestedSequence[Union[*scalars]] (but excluding non-ascii binary)
            # - _NestedSequence[_SupportsArray] (but with a single leaf element
            # .  to avoid the issue of unequally sized leaves)
            st.recursive(
                st.lists(scalars(ascii_bytes_only=True)), extend=st.tuples
            ),
            st.recursive(st.from_type(np.ndarray), extend=st.tuples),
        )

    if isinstance(thing, type) and issubclass(thing, np.generic):
        dtype = np.dtype(thing)
        return from_dtype(dtype) if dtype.kind not in "OV" else None

    real_thing, args = _unpack_generic(thing)

    if real_thing == _NestedSequence:
        # We have to override the default resolution to ensure sequences are of
        # equal length. Actually they are still not, if the arg specialization
        # returns arbitrary-shaped sequences or arrays - hence the even more special
        # resolution of ArrayLike, above.
        assert len(args) <= 1
        base_strat = st.from_type(args[0]) if args else scalars()
        return st.one_of(
            st.lists(base_strat),
            st.recursive(st.tuples(), st.tuples),
            st.recursive(st.tuples(base_strat), st.tuples),
            st.recursive(st.tuples(base_strat, base_strat), st.tuples),
        )

    if real_thing in [np.ndarray, _SupportsArray]:
        dtype, shape = _dtype_and_shape_from_args(args)
        return arrays(dtype, shape)

    # We didn't find a type to resolve, continue
    return None
19 changes: 9 additions & 10 deletions hypothesis-python/src/hypothesis/strategies/_internal/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,14 @@ def from_type_guarded(thing):
finally:
recurse_guard.pop()

# Let registered extra modules handle their own recognized types first, before
# e.g. Unions are resolved
if thing not in types._global_type_lookup:
for module, resolver in types._global_extra_lookup.items():
if module in sys.modules:
strat = resolver(thing)
if strat is not None:
return strat
if not isinstance(thing, type):
if types.is_a_new_type(thing):
# Check if we have an explicitly registered strategy for this thing,
Expand Down Expand Up @@ -1187,6 +1195,7 @@ def from_type_guarded(thing):
# We need to work with their type instead.
if isinstance(thing, TypeVar) and type(thing) in types._global_type_lookup:
return as_strategy(types._global_type_lookup[type(thing)], thing)

# If there's no explicitly registered strategy, maybe a subtype of thing
# is registered - if so, we can resolve it to the subclass strategy.
# We'll start by checking if thing is from the typing module,
Expand Down Expand Up @@ -1215,16 +1224,6 @@ def from_type_guarded(thing):
# may be able to fall back on type annotations.
if issubclass(thing, enum.Enum):
return sampled_from(thing)
# Handle numpy types. If numpy is not imported, the type cannot be numpy related.
if "numpy" in sys.modules:
import numpy as np

if issubclass(thing, np.generic):
dtype = np.dtype(thing)
if dtype.kind not in "OV":
from hypothesis.extra.numpy import from_dtype

return from_dtype(dtype)
# Finally, try to build an instance by calling the type object. Unlike builds(),
# this block *does* try to infer strategies for arguments with default values.
# That's because of the semantic difference; builds() -> "call this with ..."
Expand Down
31 changes: 20 additions & 11 deletions hypothesis-python/src/hypothesis/strategies/_internal/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,17 +594,6 @@ def _networks(bits):
_global_type_lookup[os._Environ] = st.just(os.environ)


try: # pragma: no cover
import numpy as np

from hypothesis.extra.numpy import array_dtypes, array_shapes, arrays, scalar_dtypes

_global_type_lookup[np.dtype] = array_dtypes()
_global_type_lookup[np.ndarray] = arrays(scalar_dtypes(), array_shapes(max_dims=2))
except ImportError:
pass


_global_type_lookup.update(
{
# Note: while ByteString notionally also represents the bytearray and
Expand Down Expand Up @@ -667,6 +656,26 @@ def _networks(bits):
_global_type_lookup[typing.SupportsIndex] = st.integers() | st.booleans()


# The "extra" lookups define a callable that either resolves to a strategy for
# this narrowly extra-specific type, or returns None to proceed with normal
# type resolution. The callable will only be called if the module has already
# been imported (i.e. is in sys.modules), not merely installed. To avoid the
# performance hit of importing anything here, we defer
# it until the method is called the first time, at which point we replace the
# entry in the lookup table with the direct call.
def _from_numpy_type(thing: typing.Type) -> typing.Optional[st.SearchStrategy]:
    """Resolve ``thing`` via the numpy extra, importing that extra on first use.

    The import happens inside the function so that nothing is imported until
    a resolution is actually requested. Once imported, the real resolver
    replaces this function in ``_global_extra_lookup``, so later lookups
    call it directly.
    """
    from hypothesis.extra.numpy import _from_type as numpy_resolver

    _global_extra_lookup["numpy"] = numpy_resolver
    return numpy_resolver(thing)


# Keyed by module name. Each value is a resolver that takes a type and returns
# either a strategy for it or None.
_global_extra_lookup: typing.Dict[
    str, typing.Callable[[typing.Type], typing.Optional[st.SearchStrategy]]
] = {
    # Starts out as the lazy wrapper, which swaps itself for the real resolver
    # the first time it is called.
    "numpy": _from_numpy_type,
}


def register(type_, fallback=None, *, module=typing):
if isinstance(type_, str):
# Use the name of generic types which are not available on all
Expand Down
8 changes: 0 additions & 8 deletions hypothesis-python/tests/numpy/test_from_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from hypothesis.errors import InvalidArgument
from hypothesis.extra import numpy as nps
from hypothesis.internal.floats import width_smallest_normals
from hypothesis.strategies import from_type
from hypothesis.strategies._internal import SearchStrategy

from tests.common.debug import assert_no_examples, find_any
Expand Down Expand Up @@ -284,10 +283,3 @@ def condition(n):
find_any(strat, condition)
else:
assert_no_examples(strat, condition)


@pytest.mark.parametrize("dtype", STANDARD_TYPES)
def test_resolves_and_varies_numpy_type(dtype):
# Check that we find an instance that is not equal to the default
x = find_any(from_type(dtype.type), lambda x: x != type(x)())
assert isinstance(x, dtype.type)
Loading

0 comments on commit 54bf5dc

Please sign in to comment.