diff --git a/jax/_src/array.py b/jax/_src/array.py index 7ae7ec6e6075..15680c399160 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -308,13 +308,6 @@ def __getitem__(self, idx): from jax._src.numpy import lax_numpy self._check_if_deleted() - if isinstance(idx, tuple): - num_idx = sum(e is not None and e is not Ellipsis for e in idx) - if num_idx > self.ndim: - raise IndexError( - f"Too many indices for array: array has ndim of {self.ndim}, but " - f"was indexed with {num_idx} non-None/Ellipsis indices.") - if isinstance(self.sharding, PmapSharding): if config.pmap_no_rank_reduction.value: cidx = idx if isinstance(idx, tuple) else (idx,) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 481ba190991b..6502463586cf 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -4613,6 +4613,9 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, if isinstance(fill_value, np.ndarray): fill_value = fill_value.item() + if indexer.scalar_bool_dims: + y = lax.expand_dims(y, indexer.scalar_bool_dims) + # Avoid calling gather if the slice shape is empty, both as a fast path and to # handle cases like zeros(0)[array([], int32)]. if core.is_empty_shape(indexer.slice_shape): @@ -4657,6 +4660,10 @@ class _Indexer(NamedTuple): # gathers and eliminated for scatters. newaxis_dims: Sequence[int] + # Keep track of dimensions with scalar bool indices. These must be inserted + # for gathers before performing other index operations. + scalar_bool_dims: Sequence[int] + def _split_index_for_jit(idx, shape): """Splits indices into necessarily-static and dynamic parts. @@ -4705,6 +4712,16 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], # Remove ellipses and add trailing slice(None)s. idx = _canonicalize_tuple_index(len(x_shape), idx) + # Check for scalar boolean indexing: this requires inserting extra dimensions + # before performing the rest of the logic. + scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)] + if scalar_bool_dims: + idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx) + x_shape = list(x_shape) + for i in sorted(scalar_bool_dims): + x_shape.insert(i, 1) + x_shape = tuple(x_shape) + # Check for advanced indexing: # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing @@ -4805,8 +4822,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], # XLA gives error when indexing into an axis of size 0 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i - i = lax.convert_element_type(i, index_dtype) - gather_indices.append((i, len(gather_indices_shape))) + i_converted = lax.convert_element_type(i, index_dtype) + gather_indices.append((i_converted, len(gather_indices_shape))) collapsed_slice_dims.append(x_axis) gather_slice_shape.append(1) start_index_map.append(x_axis) @@ -4816,7 +4833,6 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], slice_shape.append(1) newaxis_dims.append(y_axis) y_axis += 1 - elif isinstance(i, slice): # Handle slice index (only static, otherwise an error is raised) if not all(_is_slice_element_none_or_constant_or_symbolic(elt) @@ -4893,7 +4909,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any], dnums=dnums, gather_indices=gather_indices_array, unique_indices=advanced_indexes is None, - indices_are_sorted=advanced_indexes is None) + indices_are_sorted=advanced_indexes is None, + scalar_bool_dims=scalar_bool_dims) def _should_unpack_list_index(x): """Helper for _eliminate_deprecated_list_indexing.""" @@ -4959,7 +4976,7 @@ def _expand_bool_indices(idx, shape): # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete raise errors.NonConcreteBooleanIndexError(abstract_i) elif _ndim(i) == 0: - raise TypeError("JAX arrays do not support boolean scalar indices") + out.append(bool(i)) else: i_shape = _shape(i) start = len(out) + ellipsis_offset - newaxis_offset @@ -5010,10 +5027,10 @@ def _is_scalar(x): def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'): """Helper to remove Ellipsis and add in the implicit trailing slice(None).""" - len_without_none = sum(e is not None and e is not Ellipsis for e in idx) - if len_without_none > arr_ndim: + num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx) + if num_dimensions_consumed > arr_ndim: raise IndexError( - f"Too many indices for {array_name}: {len_without_none} " + f"Too many indices for {array_name}: {num_dimensions_consumed} " f"non-None/Ellipsis indices for dim {arr_ndim}.") ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis) ellipsis_index = next(ellipses, None) @@ -5021,10 +5038,10 @@ def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'): if next(ellipses, None) is not None: raise IndexError( f"Multiple ellipses (...) not supported: {list(map(type, idx))}.") - colons = (slice(None),) * (arr_ndim - len_without_none) + colons = (slice(None),) * (arr_ndim - num_dimensions_consumed) idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:] - elif len_without_none < arr_ndim: - colons = (slice(None),) * (arr_ndim - len_without_none) + elif num_dimensions_consumed < arr_ndim: + colons = (slice(None),) * (arr_ndim - num_dimensions_consumed) idx = tuple(idx) + colons return idx diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 17bb5027fcd7..5934ae415610 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -103,6 +103,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx, idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) indexer = jnp._index_to_gather(jnp.shape(x), idx, normalize_indices=normalize_indices) + # TODO(jakevdp): implement scalar boolean logic. + if indexer.scalar_bool_dims: + raise TypeError("Scalar boolean indices are not allowed in scatter.") # Avoid calling scatter if the slice shape is empty, both as a fast path and # to handle cases like zeros(0)[array([], int32)]. diff --git a/jax/experimental/array_api/skips.txt b/jax/experimental/array_api/skips.txt index 64cf979a177e..76149272ae83 100644 --- a/jax/experimental/array_api/skips.txt +++ b/jax/experimental/array_api/skips.txt @@ -1,8 +1,5 @@ # Known failures for the array api tests. -# JAX doesn't yet support scalar boolean indexing -array_api_tests/test_array_object.py::test_getitem_masking - # Test suite attempts in-place mutation: array_api_tests/test_special_cases.py::test_binary array_api_tests/test_special_cases.py::test_iop diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 9e326c5f0e82..f5029226fbe7 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -876,15 +876,6 @@ def testBooleanIndexingDynamicShapeError(self): i = np.array([True, True, False]) self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i)) - def testScalarBooleanIndexingNotImplemented(self): - msg = "JAX arrays do not support boolean scalar indices" - with self.assertRaisesRegex(TypeError, msg): - jnp.arange(4)[True] - with self.assertRaisesRegex(TypeError, msg): - jnp.arange(4)[False] - with self.assertRaisesRegex(TypeError, msg): - jnp.arange(4)[..., True] - def testIssue187(self): x = jnp.ones((5, 5)) x[[0, 2, 4], [0, 2, 4]] # doesn't crash @@ -1033,6 +1024,29 @@ def testNontrivialBooleanIndexing(self): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + shape=[(2, 3, 4, 5)], + idx=[ + np.index_exp[True], + np.index_exp[False], + np.index_exp[..., True], + np.index_exp[..., False], + np.index_exp[0, :2, True], + np.index_exp[0, :2, False], + np.index_exp[:2, 0, True], + np.index_exp[:2, 0, False], + np.index_exp[:2, np.array([0, 2]), True], + np.index_exp[np.array([1, 0]), :, True], + np.index_exp[True, :, True, :, np.array(True)], + ] + ) + def testScalarBooleanIndexing(self, shape, idx): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, np.int32)] + np_fun = lambda x: np.asarray(x)[idx] + jnp_fun = lambda x: jnp.asarray(x)[idx] + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + def testFloatIndexingError(self): BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type" with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR): @@ -1158,8 +1172,7 @@ def _check_raises(x_type, y_type, msg): def testWrongNumberOfIndices(self): with self.assertRaisesRegex( IndexError, - "Too many indices for array: array has ndim of 1, " - "but was indexed with 2 non-None/Ellipsis indices"): + "Too many indices for array: 2 non-None/Ellipsis indices for dim 1."): jnp.zeros(3)[:, 5]