Skip to content

Commit

Permalink
Merge pull request #18783 from gnecula:fix_indexing
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587328492
  • Loading branch information
jax authors committed Dec 2, 2023
2 parents 32fb1b4 + d2f6261 commit b822801
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 20 deletions.
46 changes: 26 additions & 20 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4936,6 +4936,9 @@ def _preprocess_slice(
) -> tuple[core.DimSize, core.DimSize, core.DimSize]:
"""Computes the start index, step, and size of the slice `x[s]`."""
# See https://numpy.org/doc/stable/user/basics.indexing.html#slicing-and-striding
# "this is harder to get right than you may think"
# (from https://github.com/python/cpython/blob/939fc6d6eab9b7ea8c244d513610dbdd556503a7/Objects/sliceobject.c#L275)

# Must resolve statically if step is {<0, ==0, >0}
step = s.step if s.step is not None else 1
try:
Expand All @@ -4946,35 +4949,38 @@ def _preprocess_slice(
raise core.InconclusiveDimensionOperation(
f"In slice with non-constant elements the step ({step}) must " +
f"be resolved statically if it is > 0 or < 0.\nDetails: {e}")
if s.start is None:
start = 0 if step_gt_0 else axis_size - 1
else:
start = s.start

def clamp_index(i: DimSize, which: str):
try:
start_ge_0 = (start >= 0)
i_ge_0 = (i >= 0)
except core.InconclusiveDimensionOperation as e:
raise core.InconclusiveDimensionOperation(
f"In slice with non-constant elements the start ({start}) must " +
f"In slice with non-constant elements the {which} ({i}) must " +
f"be resolved statically if it is >= 0.\nDetails: {e}")
if start_ge_0:
start = axis_size - core.non_negative_dim(axis_size - start) # min(axis_size, start)
if i_ge_0:
if step_gt_0:
# min(i, axis_size)
return axis_size - core.non_negative_dim(axis_size - i)
else:
# min(i, axis_size - 1)
return axis_size - 1 - core.non_negative_dim(axis_size - 1 - i)
else:
start = core.non_negative_dim(axis_size + start) # max(axis_size + start, 0)
if step_gt_0:
# max(axis_size + i, 0)
return core.non_negative_dim(axis_size + i)
else:
# max(axis_size + i, -1)
return -1 + core.non_negative_dim(axis_size + i + 1)

if s.start is None:
start = 0 if step_gt_0 else axis_size - 1
else:
start = clamp_index(s.start, "start")

if s.stop is None:
stop = axis_size if step_gt_0 else -1
else:
stop = s.stop
try:
stop_ge_0 = (stop >= 0)
except core.InconclusiveDimensionOperation as e:
raise core.InconclusiveDimensionOperation(
f"In slice with non-constant elements the stop ({stop}) must " +
f"be resolved statically if it is >= 0.\nDetails: {e}")
if stop_ge_0:
stop = axis_size - core.non_negative_dim(axis_size - stop) # min(axis_size, stop)
else:
stop = core.non_negative_dim(axis_size + stop) # max(axis_size + stop, 0)
stop = clamp_index(s.stop, "stop")

gap = step if step_gt_0 else - step
distance = (stop - start) if step_gt_0 else (start - stop)
Expand Down
7 changes: 7 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def check_grads(f, args, order, atol=None, rtol=None, eps=None):
IndexSpec(shape=(10, 8), indexer=slice(0, 8, -1), out_shape=(0, 8)),
IndexSpec(shape=(10, 8), indexer=slice(None, None, -1), out_shape=(10, 8)),
]),
("SliceIndexClamping", [
IndexSpec(shape=(10,), indexer=slice(2, 11, 1), out_shape=(8,)),
IndexSpec(shape=(10,), indexer=slice(11, 12, 1), out_shape=(0,)),
IndexSpec(shape=(10,), indexer=slice(-11, -2, 1), out_shape=(8,)),
IndexSpec(shape=(10,), indexer=slice(-2, -12, -1), out_shape=(9,)),
IndexSpec(shape=(10,), indexer=slice(12, -12, -1), out_shape=(10,)),
]),
("OneSliceIndexNonUnitStride", [
IndexSpec(shape=(10,), indexer=slice(0, 8, 2), out_shape=(4,)),
IndexSpec(shape=(10,), indexer=slice(0, 8, 3), out_shape=(3,)),
Expand Down
7 changes: 7 additions & 0 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from jax._src import config
from jax._src import core
from jax._src import test_util as jtu
from jax._src import tree_util
from jax._src.lax import lax as lax_internal
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lib import xla_client
Expand Down Expand Up @@ -531,6 +532,7 @@ def log_message(extra: str):

f_jax = self.dyn_fun
args = self.dyn_args_maker(tst.rng())
args = tree_util.tree_map(jnp.array, args)
args_specs = export.args_specs(args, self.polymorphic_shapes)

if self.expect_error is not None:
Expand Down Expand Up @@ -1199,21 +1201,26 @@ def test_vmap_error(self):
arg_descriptors=[RandArg((16, 8), np.float32)],
polymorphic_shapes=["c, b"])
# start, stop, step are functions that take the argument "b"
# "b" actual value is 8 and "c" is 16
for start_name, start in [
("None", lambda b: None),
("0", lambda b: 0),
("2", lambda b: 2),
("b", lambda b: b),
("2b2", lambda b: 2*b + 2),
("-2", lambda b: -2),
("-b", lambda b: -b),
("-2b2", lambda b: -2*b - 2),
]
for stop_name, stop in [
("None", lambda b: None),
("0", lambda b: 0),
("b", lambda b: b),
("2b2", lambda b: 2*b + 2),
("4", lambda b: 4),
("-4", lambda b: -4),
("-b", lambda b: -b),
("-2b2", lambda b: -2*b - 2),
]
for step_name, step in [
("None", lambda b: None),
Expand Down

0 comments on commit b822801

Please sign in to comment.