Skip to content

Commit

Permalink
Merge pull request #19400 from jakevdp:jnp-isdtype
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599265256
  • Loading branch information
jax authors committed Jan 17, 2024
2 parents 3e80670 + fbf7492 commit df28ee7
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 28 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ namespace; they are listed below.
isclose
iscomplex
iscomplexobj
isdtype
isfinite
isin
isinf
Expand Down
62 changes: 60 additions & 2 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,32 +376,40 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool:
# Enumeration of all valid JAX types in order.
_weak_types: list[JAXType] = [int, float, complex]
_bool_types: list[JAXType] = [np.dtype(bool)]
_signed_types: list[JAXType]
_unsigned_types: list[JAXType]
_int_types: list[JAXType]
if int4 is not None:
_int_types = [
_unsigned_types = [
np.dtype(uint4),
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype(int4),
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]
else:
_int_types = [
_unsigned_types = [
np.dtype('uint8'),
np.dtype('uint16'),
np.dtype('uint32'),
np.dtype('uint64'),
]
_signed_types = [
np.dtype('int8'),
np.dtype('int16'),
np.dtype('int32'),
np.dtype('int64'),
]

_int_types = _unsigned_types + _signed_types

_float_types: list[JAXType] = [
*_custom_float_dtypes,
np.dtype('float16'),
Expand All @@ -415,6 +423,56 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool:
_jax_types = _bool_types + _int_types + _float_types + _complex_types
_jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types}


_dtype_kinds: dict[str, set] = {
'bool': {*_bool_types},
'signed integer': {*_signed_types},
'unsigned integer': {*_unsigned_types},
'integral': {*_signed_types, *_unsigned_types},
'real floating': {*_float_types},
'complex floating': {*_complex_types},
'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types},
}


def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType]) -> bool:
"""Returns a boolean indicating whether a provided dtype is of a specified kind.
Args:
dtype : the input dtype
kind : the data type kind.
If ``kind`` is dtype-like, return ``dtype = kind``.
If ``kind`` is a string, then return True if the dtype is in the specified category:
- ``'bool'``: ``{bool}``
- ``'signed integer'``: ``{int4, int8, int16, int32, int64}``
- ``'unsigned integer'``: ``{uint4, uint8, uint16, uint32, uint64}``
- ``'integral'``: shorthand for ``('signed integer', 'unsigned integer')``
- ``'real floating'``: ``{float8_*, float16, bfloat16, float32, float64}``
- ``'complex floating'``: ``{complex64, complex128}``
- ``'numeric'``: shorthand for ``('integral', 'real floating', 'complex floating')``
If ``kind`` is a tuple, then return True if dtype matches any entry of the tuple.
Returns:
True or False
"""
the_dtype = np.dtype(dtype)
kind_tuple: tuple[DType | str] = kind if isinstance(kind, tuple) else (kind,)
options: set[DType] = set()
for kind in kind_tuple:
if isinstance(kind, str):
if kind not in _dtype_kinds:
raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}")
options.update(_dtype_kinds[kind])
elif isinstance(kind, np.dtype):
options.add(kind)
else:
# TODO(jakevdp): should we handle scalar types or ScalarMeta here?
raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}")
return the_dtype in options


def _jax_type(dtype: DType, weak_type: bool) -> JAXType:
"""Return the jax type for a dtype and weak type."""
if weak_type:
Expand Down
27 changes: 1 addition & 26 deletions jax/experimental/array_api/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,33 +176,8 @@ def iinfo(type, /) -> IInfo:
return IInfo(bits=info.bits, max=info.max, min=info.min, dtype=jnp.dtype(type))


_dtype_kinds = {
'bool': {bool},
'signed integer': {int8, int16, int32, int64},
'unsigned integer': {uint8, uint16, uint32, uint64},
'integral': {int8, int16, int32, int64, uint8, uint16, uint32, uint64},
'real floating': {float32, float64},
'complex floating': {complex64, complex128},
'numeric': {int8, int16, int32, int64, uint8, uint16, uint32, uint64,
float32, float64, complex64, complex128},
}

def isdtype(dtype, kind):
if not _is_valid_dtype(dtype):
raise ValueError(f"{dtype} is not a valid dtype.")
if isinstance(kind, tuple):
return any(_isdtype(dtype, k) for k in kind)
return _isdtype(dtype, kind)

def _isdtype(dtype, kind):
if isinstance(kind, jnp.dtype):
return dtype == kind
elif isinstance(kind, str):
if kind not in _dtype_kinds:
raise ValueError(f"Unrecognized {kind=!r}")
return dtype in _dtype_kinds[kind]
else:
raise ValueError(f"Invalid kind with {kind}. Expected string or dtype.")
return jax.numpy.isdtype(dtype, kind)


def result_type(*arrays_and_dtypes):
Expand Down
4 changes: 4 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

from jax._src.basearray import Array as ndarray

from jax._src.dtypes import (
isdtype as isdtype,
)

from jax._src.numpy.lax_numpy import (
ComplexWarning as ComplexWarning,
allclose as allclose,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ...,
atol: ArrayLike = ..., equal_nan: bool = ...) -> Array: ...
def iscomplex(m: ArrayLike) -> Array: ...
def iscomplexobj(x: Any) -> bool: ...
def isdtype(dtype: DTypeLike, kind: Union[DType, str, tuple[Union[DType, str], ...]]) -> bool: ...
def isfinite(x: ArrayLike, /) -> Array: ...
def isin(element: ArrayLike, test_elements: ArrayLike,
assume_unique: bool = ..., invert: bool = ...) -> Array: ...
Expand Down
21 changes: 21 additions & 0 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@
jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64,
jnp.complex64, jnp.complex128]

dtype_kinds = {
'bool': bool_dtypes,
'signed integer': signed_dtypes,
'unsigned integer': unsigned_dtypes,
'integral': signed_dtypes + unsigned_dtypes,
'real floating': float_dtypes,
'complex floating': complex_dtypes,
'numeric': signed_dtypes + unsigned_dtypes + float_dtypes + complex_dtypes,
}

python_scalar_types = [bool, int, float, complex]

_EXPECTED_CANONICALIZE_X64 = {value: value for value in scalar_types}
Expand Down Expand Up @@ -325,6 +335,17 @@ def testIsSubdtypeInt4(self, dtype):
self.assertFalse(dtypes.issubdtype(dt, np.int64))
self.assertFalse(dtypes.issubdtype(np.generic, dt))

@jtu.sample_product(
dtype=all_dtypes,
kind=(*dtype_kinds, *all_dtypes)
)
def testIsDtype(self, dtype, kind):
if isinstance(kind, np.dtype):
expected = (dtype == kind)
else:
expected = (dtype in dtype_kinds[kind])
self.assertEqual(expected, dtypes.isdtype(dtype, kind))

def testArrayCasts(self):
for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]:
a = np.array([1, 2.5, -3.7])
Expand Down
15 changes: 15 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5460,6 +5460,21 @@ def test_error_hint(self, fn):
r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"):
fn(2, 3)

@jtu.sample_product(
dtype=jtu.dtypes.all,
kind=['bool', 'signed integer', 'unsigned integer', 'integral',
'real floating', 'complex floating', 'numeric']
)
def test_isdtype(self, dtype, kind):
# Full tests also in dtypes_test.py; here we just compare against numpy
jax_result = jnp.isdtype(dtype, kind)
if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16:
# just a smoke test
self.assertIsInstance(jax_result, bool)
else:
numpy_result = np.isdtype(dtype, kind)
self.assertEqual(jax_result, numpy_result)


# Most grad tests are at the lax level (see lax_test.py), but we add some here
# as needed for e.g. particular compound ops of interest.
Expand Down

0 comments on commit df28ee7

Please sign in to comment.