From fbf7492a2cd3a0cafe68e41fc4be98f3ed627540 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 17 Jan 2024 12:14:55 -0800 Subject: [PATCH] Add jnp.isdtype function, following np.isdtype in NumPy 2.0 --- docs/jax.numpy.rst | 1 + jax/_src/dtypes.py | 62 ++++++++++++++++++- .../array_api/_data_type_functions.py | 27 +------- jax/numpy/__init__.py | 4 ++ jax/numpy/__init__.pyi | 1 + tests/dtypes_test.py | 21 +++++++ tests/lax_numpy_test.py | 15 +++++ 7 files changed, 103 insertions(+), 28 deletions(-) diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index e47858a7c584..cdc557477c4e 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -235,6 +235,7 @@ namespace; they are listed below. isclose iscomplex iscomplexobj + isdtype isfinite isin isinf diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 48214c7a41f7..ebe3135c1f05 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -376,14 +376,18 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool: # Enumeration of all valid JAX types in order. 
_weak_types: list[JAXType] = [int, float, complex] _bool_types: list[JAXType] = [np.dtype(bool)] +_signed_types: list[JAXType] +_unsigned_types: list[JAXType] _int_types: list[JAXType] if int4 is not None: - _int_types = [ + _unsigned_types = [ np.dtype(uint4), np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64'), + ] + _signed_types = [ np.dtype(int4), np.dtype('int8'), np.dtype('int16'), @@ -391,17 +395,21 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool: np.dtype('int64'), ] else: - _int_types = [ + _unsigned_types = [ np.dtype('uint8'), np.dtype('uint16'), np.dtype('uint32'), np.dtype('uint64'), + ] + _signed_types = [ np.dtype('int8'), np.dtype('int16'), np.dtype('int32'), np.dtype('int64'), ] +_int_types = _unsigned_types + _signed_types + _float_types: list[JAXType] = [ *_custom_float_dtypes, np.dtype('float16'), @@ -415,6 +423,56 @@ def issubdtype(a: DTypeLike | None, b: DTypeLike | None) -> bool: _jax_types = _bool_types + _int_types + _float_types + _complex_types _jax_dtype_set = {float0, *_bool_types, *_int_types, *_float_types, *_complex_types} + +_dtype_kinds: dict[str, set] = { + 'bool': {*_bool_types}, + 'signed integer': {*_signed_types}, + 'unsigned integer': {*_unsigned_types}, + 'integral': {*_signed_types, *_unsigned_types}, + 'real floating': {*_float_types}, + 'complex floating': {*_complex_types}, + 'numeric': {*_signed_types, *_unsigned_types, *_float_types, *_complex_types}, +} + + +def isdtype(dtype: DTypeLike, kind: str | DType | tuple[str | DType]) -> bool: + """Returns a boolean indicating whether a provided dtype is of a specified kind. + + Args: + dtype : the input dtype + kind : the data type kind. + If ``kind`` is dtype-like, return ``dtype == kind``. 
+ If ``kind`` is a string, then return True if the dtype is in the specified category: + + - ``'bool'``: ``{bool}`` + - ``'signed integer'``: ``{int4, int8, int16, int32, int64}`` + - ``'unsigned integer'``: ``{uint4, uint8, uint16, uint32, uint64}`` + - ``'integral'``: shorthand for ``('signed integer', 'unsigned integer')`` + - ``'real floating'``: ``{float8_*, float16, bfloat16, float32, float64}`` + - ``'complex floating'``: ``{complex64, complex128}`` + - ``'numeric'``: shorthand for ``('integral', 'real floating', 'complex floating')`` + + If ``kind`` is a tuple, then return True if dtype matches any entry of the tuple. + + Returns: + True or False + """ + the_dtype = np.dtype(dtype) + kind_tuple: tuple[DType | str] = kind if isinstance(kind, tuple) else (kind,) + options: set[DType] = set() + for kind in kind_tuple: + if isinstance(kind, str): + if kind not in _dtype_kinds: + raise ValueError(f"Unrecognized {kind=} expected one of {list(_dtype_kinds.keys())}") + options.update(_dtype_kinds[kind]) + elif isinstance(kind, np.dtype): + options.add(kind) + else: + # TODO(jakevdp): should we handle scalar types or ScalarMeta here? 
+ raise TypeError(f"Expected kind to be a dtype, string, or tuple; got {kind=}") + return the_dtype in options + + def _jax_type(dtype: DType, weak_type: bool) -> JAXType: """Return the jax type for a dtype and weak type.""" if weak_type: diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 7403136cfff1..d2bb032b85ab 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -176,33 +176,8 @@ def iinfo(type, /) -> IInfo: return IInfo(bits=info.bits, max=info.max, min=info.min, dtype=jnp.dtype(type)) -_dtype_kinds = { - 'bool': {bool}, - 'signed integer': {int8, int16, int32, int64}, - 'unsigned integer': {uint8, uint16, uint32, uint64}, - 'integral': {int8, int16, int32, int64, uint8, uint16, uint32, uint64}, - 'real floating': {float32, float64}, - 'complex floating': {complex64, complex128}, - 'numeric': {int8, int16, int32, int64, uint8, uint16, uint32, uint64, - float32, float64, complex64, complex128}, -} - def isdtype(dtype, kind): - if not _is_valid_dtype(dtype): - raise ValueError(f"{dtype} is not a valid dtype.") - if isinstance(kind, tuple): - return any(_isdtype(dtype, k) for k in kind) - return _isdtype(dtype, kind) - -def _isdtype(dtype, kind): - if isinstance(kind, jnp.dtype): - return dtype == kind - elif isinstance(kind, str): - if kind not in _dtype_kinds: - raise ValueError(f"Unrecognized {kind=!r}") - return dtype in _dtype_kinds[kind] - else: - raise ValueError(f"Invalid kind with {kind}. 
Expected string or dtype.") + return jax.numpy.isdtype(dtype, kind) def result_type(*arrays_and_dtypes): diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 2bdb4724b527..032eb8eb584e 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -20,6 +20,10 @@ from jax._src.basearray import Array as ndarray +from jax._src.dtypes import ( + isdtype as isdtype, +) + from jax._src.numpy.lax_numpy import ( ComplexWarning as ComplexWarning, allclose as allclose, diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 92848a8efb69..bfb0561240da 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -448,6 +448,7 @@ def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: bool = ...) -> Array: ... def iscomplex(m: ArrayLike) -> Array: ... def iscomplexobj(x: Any) -> bool: ... +def isdtype(dtype: DTypeLike, kind: Union[DType, str, tuple[Union[DType, str], ...]]) -> bool: ... def isfinite(x: ArrayLike, /) -> Array: ... def isin(element: ArrayLike, test_elements: ArrayLike, assume_unique: bool = ..., invert: bool = ...) -> Array: ... 
diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index b6fe23f58ff1..4f563876cb91 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -71,6 +71,16 @@ jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64, jnp.complex64, jnp.complex128] +dtype_kinds = { + 'bool': bool_dtypes, + 'signed integer': signed_dtypes, + 'unsigned integer': unsigned_dtypes, + 'integral': signed_dtypes + unsigned_dtypes, + 'real floating': float_dtypes, + 'complex floating': complex_dtypes, + 'numeric': signed_dtypes + unsigned_dtypes + float_dtypes + complex_dtypes, +} + python_scalar_types = [bool, int, float, complex] _EXPECTED_CANONICALIZE_X64 = {value: value for value in scalar_types} @@ -325,6 +335,17 @@ def testIsSubdtypeInt4(self, dtype): self.assertFalse(dtypes.issubdtype(dt, np.int64)) self.assertFalse(dtypes.issubdtype(np.generic, dt)) + @jtu.sample_product( + dtype=all_dtypes, + kind=(*dtype_kinds, *all_dtypes) + ) + def testIsDtype(self, dtype, kind): + if isinstance(kind, np.dtype): + expected = (dtype == kind) + else: + expected = (dtype in dtype_kinds[kind]) + self.assertEqual(expected, dtypes.isdtype(dtype, kind)) + def testArrayCasts(self): for t in [jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64]: a = np.array([1, 2.5, -3.7]) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 23dc2d84fec7..9e4f1811c7e9 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -5460,6 +5460,21 @@ def test_error_hint(self, fn): r"when you meant `jax\.numpy\..*?\(\(2, 3\)\)`"): fn(2, 3) + @jtu.sample_product( + dtype=jtu.dtypes.all, + kind=['bool', 'signed integer', 'unsigned integer', 'integral', + 'real floating', 'complex floating', 'numeric'] + ) + def test_isdtype(self, dtype, kind): + # Full tests also in dtypes_test.py; here we just compare against numpy + jax_result = jnp.isdtype(dtype, kind) + if jtu.numpy_version() < (2, 0, 0) or dtype == dtypes.bfloat16: + # just a smoke test + self.assertIsInstance(jax_result, bool) + 
else: + numpy_result = np.isdtype(dtype, kind) + self.assertEqual(jax_result, numpy_result) + # Most grad tests are at the lax level (see lax_test.py), but we add some here # as needed for e.g. particular compound ops of interest.