Skip to content

Commit

Permalink
Simplify definition of jnp.isscalar
Browse files Browse the repository at this point in the history
The new semantics are to return True for any array-like object with zero dimensions.
Previously we only returned True for zero-dimensional array-like objects with a weak type. This ends up being more confusing/suprising than it needs to be, and the weak type dependence is rarely useful in practice.

PiperOrigin-RevId: 682656411
  • Loading branch information
Jake VanderPlas authored and Google-ML-Automation committed Oct 5, 2024
1 parent e90487e commit 45f0e9a
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 4 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## jax 0.4.35

* Breaking Changes
* {func}`jax.numpy.isscalar` now returns True for any array-like object with
zero dimensions. Previously it only returned True for zero-dimensional
array-like objects with a weak dtype.

## jax 0.4.34 (October 4, 2023)

* New Functionality
Expand Down
98 changes: 94 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,101 @@ def issubdtype(arg1: DTypeLike, arg2: DTypeLike) -> bool:
"""
return dtypes.issubdtype(arg1, arg2)

@util.implements(np.isscalar)

def isscalar(element: Any) -> bool:
if hasattr(element, '__jax_array__'):
element = element.__jax_array__()
return dtypes.is_python_scalar(element) or np.isscalar(element)
"""Return True if the input is a scalar.
JAX implementation of :func:`numpy.isscalar`. JAX's implementation differs
from NumPy's in that it considers zero-dimensional arrays to be scalars; see
the *Note* below for more details.
Args:
element: input object to check; any type is valid input.
Returns:
True if ``element`` is a scalar value or an array-like object with zero
dimensions, False otherwise.
Note:
JAX and NumPy differ in their representation of scalar values. NumPy has
special scalar objects (e.g. ``np.int32(0)``) which are distinct from
zero-dimensional arrays (e.g. ``np.array(0)``), and :func:`numpy.isscalar`
returns ``True`` for the former and ``False`` for the latter.
JAX does not define special scalar objects, but rather represents scalars as
zero-dimensional arrays. As such, :func:`jax.numpy.isscalar` returns ``True``
for both scalar objects (e.g. ``0.0`` or ``np.float32(0.0)``) and array-like
objects with zero dimensions (e.g. ``jnp.array(0.0)``, ``np.array(0.0)``).
One reason for the different conventions in ``isscalar`` is to maintain
JIT-invariance: i.e. the property that the result of a function should not
change when it is JIT-compiled. Because scalar inputs are cast to
zero-dimensional JAX arrays at JIT boundaries, the semantics of
:func:`numpy.isscalar` are such that the result changes under JIT:
>>> np.isscalar(1.0)
True
>>> jax.jit(np.isscalar)(1.0)
Array(False, dtype=bool)
By treating zero-dimensional arrays as scalars, :func:`jax.numpy.isscalar`
avoids this issue:
>>> jnp.isscalar(1.0)
True
>>> jax.jit(jnp.isscalar)(1.0)
Array(True, dtype=bool)
Examples:
In JAX, both scalars and zero-dimensional array-like objects are considered
scalars:
>>> jnp.isscalar(1.0)
True
>>> jnp.isscalar(1 + 1j)
True
>>> jnp.isscalar(jnp.array(1)) # zero-dimensional JAX array
True
>>> jnp.isscalar(jnp.int32(1)) # JAX scalar constructor
True
>>> jnp.isscalar(np.array(1.0)) # zero-dimensional NumPy array
True
>>> jnp.isscalar(np.int32(1)) # NumPy scalar type
True
Arrays with one or more dimension are not considered scalars:
>>> jnp.isscalar(jnp.array([1]))
False
>>> jnp.isscalar(np.array([1]))
False
Compare this to :func:`numpy.isscalar`, which returns ``True`` for
scalar-typed objects, and ``False`` for *all* arrays, even those with
zero dimensions:
>>> np.isscalar(np.int32(1)) # scalar object
True
>>> np.isscalar(np.array(1)) # zero-dimensional array
False
In JAX, as in NumPy, objects which are not array-like are not considered
scalars:
>>> jnp.isscalar(None)
False
>>> jnp.isscalar([1])
False
>>> jnp.isscalar(tuple())
False
>>> jnp.isscalar(slice(10))
False
"""
if (isinstance(element, (np.ndarray, jax.Array))
or hasattr(element, '__jax_array__')
or np.isscalar(element)):
return asarray(element).ndim == 0
return False

iterable = np.iterable

Expand Down
8 changes: 8 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3775,6 +3775,14 @@ def testMemoryView(self):
np.array(bytearray(b'\x2a\xf3'), ndmin=2)
)

@jtu.sample_product(value=[False, 1, 1.0, np.int32(5), np.array(16)])
def testIsScalar(self, value):
self.assertTrue(jnp.isscalar(value))

@jtu.sample_product(value=[None, [1], slice(4), (), np.array([0])])
def testIsNotScalar(self, value):
self.assertFalse(jnp.isscalar(value))

@jtu.sample_product(val=[1+1j, [1+1j], jnp.pi, np.arange(2)])
def testIsComplexObj(self, val):
args_maker = lambda: [val]
Expand Down

0 comments on commit 45f0e9a

Please sign in to comment.