Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jnp.isdtype function, following np.isdtype in NumPy 2.0 #19400

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ namespace; they are listed below.
isclose
iscomplex
iscomplexobj
isdtype
isfinite
isin
isinf
Expand Down
62 changes: 60 additions & 2 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,32 +376,40 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool:
# Enumeration of all valid JAX types in order.
_weak_types: list[JAXType] = [int, float, complex]
_bool_types: list[JAXType] = [np.dtype(bool)]
_signed_types: list[JAXType]
_unsigned_types: list[JAXType]
_int_types: list[JAXType]
if int4 is not None:
_int_types = [
_unsigned_types = [
np.dtype(uint4),
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype(int4),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]
else:
_int_types = [
_unsigned_types = [
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]

_int_types = _unsigned_types + _signed_types

_float_types: list[JAXType] = [
*_custom_float_dtypes,
np.dtype('float16'),
Expand All @@ -415,6 +423,56 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool:
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}


_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
}


def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType]) -> bool:
"""Returns a boolean indicating whether a provided dtype is of a specified kind.

Args:
dtype : the input dtype
kind : the data type kind.
If ``kind`` is dtype-like, return ``dtype = kind``.
If ``kind`` is a string, then return True if the dtype is in the specified category:

- ``'bool'``: ``{bool}``
- ``'signed integer'``: ``{int4, int8, int16, int32, int64}``
- ``'unsigned integer'``: ``{uint4, uint8, uint16, uint32, uint64}``
- ``'integral'``: shorthand for ``('signed integer', 'unsigned integer')``
- ``'real floating'``: ``{float8_*, float16, bfloat16, float32, float64}``
- ``'complex floating'``: ``{complex64, complex128}``
- ``'numeric'``: shorthand for ``('integral', 'real floating', 'complex floating')``

If ``kind`` is a tuple, then return True if dtype matches any entry of the tuple.

Returns:
True or False
"""
the_dtype = np.dtype(dtype)
kind_tuple: tuple[DType | str] = kind if isinstance(kind, tuple) else (kind,)
options: set[DType] = set()
for kind in kind_tuple:
if isinstance(kind, str):
if kind not in _dtype_kinds:
raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}")
options.update(_dtype_kinds[kind])
elif isinstance(kind, np.dtype):
options.add(kind)
else:
# TODO(jakevdp): should we handle scalar types or ScalarMeta here?
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}")
return the_dtype in options


def _jax_type(dtype: DType, weak_type: bool) -> JAXType:
"""Return the jax type for a dtype and weak type."""
if weak_type:
Expand Down
27 changes: 1 addition & 26 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,33 +176,8 @@ def iinfo(type, /) -> IInfo:
return IInfo(bits=info.bits, max=info.max, min=info.min, dtype=jnp.dtype(type))


_dtype_kinds = {
'bool': {bool},
'signed integer': {int8, int16, int32, int64},
'unsigned integer': {uint8, uint16, uint32, uint64},
'integral': {int8, int16, int32, int64, uint8, uint16, uint32, uint64},
'real floating': {float32, float64},
'complex floating': {complex64, complex128},
'numeric': {int8, int16, int32, int64, uint8, uint16, uint32, uint64,
float32, float64, complex64, complex128},
}

def isdtype(dtype, kind):
if not _is_valid_dtype(dtype):
raise ValueError(f"{dtype} is not a valid dtype.")
if isinstance(kind, tuple):
return any(_isdtype(dtype, k) for k in kind)
return _isdtype(dtype, kind)

def _isdtype(dtype, kind):
if isinstance(kind, jnp.dtype):
return dtype == kind
elif isinstance(kind, str):
if kind not in _dtype_kinds:
raise ValueError(f"Unrecognized {kind=!r}")
return dtype in _dtype_kinds[kind]
else:
raise ValueError(f"Invalid kind with {kind}. Expected string or dtype.")
return jax.numpy.isdtype(dtype, kind)


def result_type(*arrays_and_dtypes):
Expand Down
4 changes: 4 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

from jax._src.basearray import Array as ndarray

from jax._src.dtypes import (
isdtype as isdtype,
)

from jax._src.numpy.lax_numpy import (
ComplexWarning as ComplexWarning,
allclose as allclose,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ...,
atol: ArrayLike = ..., equal_nan: bool = ...) -> Array: ...
def iscomplex(m: ArrayLike) -> Array: ...
def iscomplexobj(x: Any) -> bool: ...
def isdtype(dtype: DTypeLike, kind: Union[DType, str, tuple[Union[DType, str], ...]]) -> bool: ...
def isfinite(x: ArrayLike, /) -> Array: ...
def isin(element: ArrayLike, test_elements: ArrayLike,
assume_unique: bool = ..., invert: bool = ...) -> Array: ...
Expand Down
21 changes: 21 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
jnp.complex64, jnp.complex128]

dtype_kinds = {
'bool': bool_dtypes,
'signed integer': signed_dtypes,
'unsigned integer': unsigned_dtypes,
'integral': signed_dtypes + unsigned_dtypes,
'real floating': float_dtypes,
'complex floating': complex_dtypes,
'numeric': signed_dtypes + unsigned_dtypes + float_dtypes + complex_dtypes,
}

python_scalar_types = [bool, int, float, complex]

_EXPECTED_CANONICALIZE_X64 = {value: value for value in scalar_types}
Expand Down Expand Up @@ -325,6 +335,17 @@ def testIsSubdtypeInt4(self, dtype):
self.assertFalse(dtypes.issubdtype(dt, np.int64))
self.assertFalse(dtypes.issubdtype(np.generic, dt))

@jtu.sample_product(
dtype=all_dtypes,
kind=(*dtype_kinds, *all_dtypes)
)
def testIsDtype(self, dtype, kind):
if isinstance(kind, np.dtype):
expected = (dtype == kind)
else:
expected = (dtype in dtype_kinds[kind])
self.assertEqual(expected, dtypes.isdtype(dtype, kind))

def testArrayCasts(self):
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]:
a = np.array([1, 2.5, -3.7])
Expand Down
15 changes: 15 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5460,6 +5460,21 @@ def test_error_hint(self, fn):
r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"):
fn(2, 3)

@jtu.sample_product(
dtype=jtu.dtypes.all,
kind=['bool', 'signed integer', 'unsigned integer', 'integral',
'real floating', 'complex floating', 'numeric']
)
def test_isdtype(self, dtype, kind):
# Full tests also in dtypes_test.py; here we just compare against numpy
jax_result = jnp.isdtype(dtype, kind)
if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16:
# just a smoke test
self.assertIsInstance(jax_result, bool)
else:
numpy_result = np.isdtype(dtype, kind)
self.assertEqual(jax_result, numpy_result)


# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.
Expand Down