diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f85e4833e13c..c733357a8b6f 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -310,7 +310,15 @@ def coerce_to_array(x: Any, dtype: DTypeLike | None = None) -> np.ndarray: return np.asarray(x, dtype) iinfo = ml_dtypes.iinfo -finfo = ml_dtypes.finfo + +# We cast finfo attributes from np.float to built-in python floats +# for Array API compliance +class finfo(ml_dtypes.finfo): + def __setattr__(self, name, value): + if isinstance(value, np.floating): + value = float(value) + object.__setattr__(self, name, value) + def _issubclass(a: Any, b: Any) -> bool: """Determines if ``a`` is a subclass of ``b``. diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 248c1c6dd0fe..f9799255ecf6 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -64,8 +64,6 @@ class FInfo(NamedTuple): smallest_normal: float dtype: jnp.dtype -# TODO(micky774): Update jax.numpy.finfo so that its attributes are python -# floats def finfo(type, /) -> FInfo: info = jnp.finfo(type) return FInfo(