From d9cbd7bd5e729efabe4c55f40a1a12ff20a0df9e Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 5 Feb 2024 13:18:33 -0800 Subject: [PATCH] Improve repr for empty jax.Array --- jax/_src/array.py | 7 +++++-- tests/array_test.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index 69f8960e3eef..7ae7ec6e6075 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -376,8 +376,11 @@ def __repr__(self): if self.is_fully_addressable or self.is_fully_replicated: line_width = np.get_printoptions()["linewidth"] - s = np.array2string(self._value, prefix=prefix, suffix=',', - separator=', ', max_line_width=line_width) + if self.size == 0: + s = f"[], shape={self.shape}" + else: + s = np.array2string(self._value, prefix=prefix, suffix=',', + separator=', ', max_line_width=line_width) last_line_len = len(s) - s.rfind('\n') + 1 sep = ' ' if last_line_len + len(dtype_str) + 1 > line_width: diff --git a/tests/array_test.py b/tests/array_test.py index 9e9b61915cf5..a1239bbdb015 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -231,6 +231,12 @@ def test_repr(self): input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) self.assertStartsWith(repr(arr), "Array(") + def test_empty_repr(self): + shape = (0, 5) + dtype = 'float32' + x = jnp.empty(shape, dtype) + self.assertEqual(repr(x), f"Array([], shape={shape}, dtype={dtype})") + def test_jnp_array(self): arr = jnp.array([1, 2, 3]) self.assertIsInstance(arr, array.ArrayImpl)