diff --git a/jax/_src/array.py b/jax/_src/array.py index 69f8960e3eef..7ae7ec6e6075 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -376,8 +376,11 @@ def __repr__(self): if self.is_fully_addressable or self.is_fully_replicated: line_width = np.get_printoptions()["linewidth"] - s = np.array2string(self._value, prefix=prefix, suffix=',', - separator=', ', max_line_width=line_width) + if self.size == 0: + s = f"[], shape={self.shape}" + else: + s = np.array2string(self._value, prefix=prefix, suffix=',', + separator=', ', max_line_width=line_width) last_line_len = len(s) - s.rfind('\n') + 1 sep = ' ' if last_line_len + len(dtype_str) + 1 > line_width: diff --git a/tests/array_test.py b/tests/array_test.py index 9e9b61915cf5..a1239bbdb015 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -231,6 +231,12 @@ def test_repr(self): input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) self.assertStartsWith(repr(arr), "Array(") + def test_empty_repr(self): + shape = (0, 5) + dtype = 'float32' + x = jnp.empty(shape, dtype) + self.assertEqual(repr(x), f"Array([], shape={shape}, dtype={dtype})") + def test_jnp_array(self): arr = jnp.array([1, 2, 3]) self.assertIsInstance(arr, array.ArrayImpl)