From 45f0e9ad683dc1139f8126ea123d2b0989ee8697 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Sat, 5 Oct 2024 07:11:44 -0700 Subject: [PATCH] Simplify definition of jnp.isscalar The new semantics are to return True for any array-like object with zero dimensions. Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice. PiperOrigin-RevId: 682656411 --- CHANGELOG.md | 5 ++ jax/_src/numpy/lax_numpy.py | 98 +++++++++++++++++++++++++++++++++++-- tests/lax_numpy_test.py | 8 +++ 3 files changed, 107 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 294689f5b541..fd4a71aeeb50 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## jax 0.4.35 +* Breaking Changes + * {func}`jax.numpy.isscalar` now returns True for any array-like object with + zero dimensions. Previously it only returned True for zero-dimensional + array-like objects with a weak dtype. + ## jax 0.4.34 (October 4, 2023) * New Functionality diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 0dfed07f4127..491d2bc74aaa 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -482,11 +482,101 @@ def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool: """ return dtypes.issubdtype(arg1, arg2) -@util.implements(np.isscalar) + def isscalar(element: Any) -> bool: - if hasattr(element, '__jax_array__'): - element = element.__jax_array__() - return dtypes.is_python_scalar(element) or np.isscalar(element) + """Return True if the input is a scalar. + + JAX implementation of :func:`numpy.isscalar`. JAX's implementation differs + from NumPy's in that it considers zero-dimensional arrays to be scalars; see + the *Note* below for more details. + + Args: + element: input object to check; any type is valid input. + + Returns: + True if ``element`` is a scalar value or an array-like object with zero + dimensions, False otherwise. + + Note: + JAX and NumPy differ in their representation of scalar values. NumPy has + special scalar objects (e.g. ``np.int32(0)``) which are distinct from + zero-dimensional arrays (e.g. ``np.array(0)``), and :func:`numpy.isscalar` + returns ``True`` for the former and ``False`` for the latter. + + JAX does not define special scalar objects, but rather represents scalars as + zero-dimensional arrays. As such, :func:`jax.numpy.isscalar` returns ``True`` + for both scalar objects (e.g. ``0.0`` or ``np.float32(0.0)``) and array-like + objects with zero dimensions (e.g. ``jnp.array(0.0)``, ``np.array(0.0)``). + + One reason for the different conventions in ``isscalar`` is to maintain + JIT-invariance: i.e. the property that the result of a function should not + change when it is JIT-compiled. Because scalar inputs are cast to + zero-dimensional JAX arrays at JIT boundaries, the semantics of + :func:`numpy.isscalar` are such that the result changes under JIT: + + >>> np.isscalar(1.0) + True + >>> jax.jit(np.isscalar)(1.0) + Array(False, dtype=bool) + + By treating zero-dimensional arrays as scalars, :func:`jax.numpy.isscalar` + avoids this issue: + + >>> jnp.isscalar(1.0) + True + >>> jax.jit(jnp.isscalar)(1.0) + Array(True, dtype=bool) + + Examples: + In JAX, both scalars and zero-dimensional array-like objects are considered + scalars: + + >>> jnp.isscalar(1.0) + True + >>> jnp.isscalar(1 + 1j) + True + >>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array + True + >>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor + True + >>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array + True + >>> jnp.isscalar(np.int32(1)) # NumPy scalar type + True + + Arrays with one or more dimension are not considered scalars: + + >>> jnp.isscalar(jnp.array([1])) + False + >>> jnp.isscalar(np.array([1])) + False + + Compare this to :func:`numpy.isscalar`, which returns ``True`` for + scalar-typed objects, and ``False`` for *all* arrays, even those with + zero dimensions: + + >>> np.isscalar(np.int32(1)) # scalar object + True + >>> np.isscalar(np.array(1)) # zero-dimensional array + False + + In JAX, as in NumPy, objects which are not array-like are not considered + scalars: + + >>> jnp.isscalar(None) + False + >>> jnp.isscalar([1]) + False + >>> jnp.isscalar(tuple()) + False + >>> jnp.isscalar(slice(10)) + False + """ + if (isinstance(element, (np.ndarray, jax.Array)) + or hasattr(element, '__jax_array__') + or np.isscalar(element)): + return asarray(element).ndim == 0 + return False iterable = np.iterable diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 145506dd4166..03e0250c6a57 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3775,6 +3775,14 @@ def testMemoryView(self): np.array(bytearray(b'\x2a\xf3'), ndmin=2) ) + @jtu.sample_product(value=[False, 1, 1.0, np.int32(5), np.array(16)]) + def testIsScalar(self, value): + self.assertTrue(jnp.isscalar(value)) + + @jtu.sample_product(value=[None, [1], slice(4), (), np.array([0])]) + def testIsNotScalar(self, value): + self.assertFalse(jnp.isscalar(value)) + @jtu.sample_product(val=[1+1j, [1+1j], jnp.pi, np.arange(2)]) def testIsComplexObj(self, val): args_maker = lambda: [val]