Skip to content

Commit

Permalink
Merge pull request #18566 from mattjj:jnp-reshape-type-error
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583148860
  • Loading branch information
jax authors committed Nov 16, 2023
2 parents 6299ff8 + 0b046fb commit d60014c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,25 @@ def _compute_newshape(a: ArrayLike, newshape: Union[DimSize, Shape]) -> Shape:
newshape = core.canonicalize_shape(newshape) # type: ignore[arg-type]
neg1s = [i for i, d in enumerate(newshape) if type(d) is int and d == -1]
if len(neg1s) > 1:
raise ValueError("can only specify one unknown axis size with a `-1` value, "
f"got {orig_newshape}")
raise TypeError("can only specify one unknown axis size with a `-1` value, "
f"got {orig_newshape}")
if neg1s:
i, = neg1s
other_sizes = (*newshape[:i], *newshape[i+1:])
if (all(isinstance(d, int) for d in (*np.shape(a), *other_sizes)) and
np.size(a) % math.prod(other_sizes) != 0):
raise ValueError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
f"into shape {orig_newshape} because the product of "
f"specified axis sizes ({math.prod(other_sizes)}) does "
f"not evenly divide {np.size(a)}")
raise TypeError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
f"into shape {orig_newshape} because the product of "
f"specified axis sizes ({math.prod(other_sizes)}) does "
f"not evenly divide {np.size(a)}")
sz = core.cancel_divide_tracers(np.shape(a), other_sizes)
if sz is not None:
return (*newshape[:i], sz, *newshape[i+1:])
else:
if (all(isinstance(d, int) for d in (*np.shape(a), *newshape)) and
np.size(a) != math.prod(newshape)):
raise ValueError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
f"into shape {orig_newshape} (size {math.prod(newshape)})")
raise TypeError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) "
f"into shape {orig_newshape} (size {math.prod(newshape)})")
return tuple(-core.divide_shape_sizes(np.shape(a), newshape)
if core.definitely_equal(d, -1) else d for d in newshape)

Expand Down

0 comments on commit d60014c

Please sign in to comment.