Skip to content

Commit

Permalink
[array api] add jax.numpy.concat
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 11, 2024
1 parent 35fc2ed commit c83c0f3
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ namespace; they are listed below.
complexfloating
ComplexWarning
compress
concat
concatenate
conj
conjugate
Expand Down
10 changes: 8 additions & 2 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1869,10 +1869,10 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
util.check_arraylike("concatenate", *arrays)
if not len(arrays):
raise ValueError("Need at least one array to concatenate.")
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
if axis is None:
return concatenate([ravel(a) for a in arrays], axis=0, dtype=dtype)
if ndim(arrays[0]) == 0:
raise ValueError("Zero-dimensional arrays cannot be concatenated.")
axis = _canonicalize_axis(axis, ndim(arrays[0]))
if dtype is None:
arrays_out = util.promote_dtypes(*arrays)
Expand All @@ -1888,6 +1888,12 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike],
return arrays_out[0]


@util._wraps(getattr(np, "concat", None))
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array:
util.check_arraylike("concat", *arrays)
return jax.numpy.concatenate(arrays, axis=axis)


@util._wraps(np.vstack)
def vstack(tup: np.ndarray | Array | Sequence[ArrayLike],
dtype: DTypeLike | None = None) -> Array:
Expand Down
5 changes: 1 addition & 4 deletions jax/experimental/array_api/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ def broadcast_to(x: Array, /, shape: tuple[int]) -> Array:
def concat(arrays: tuple[Array, ...] | list[Array], /, *, axis: int | None = 0) -> Array:
"""Joins a sequence of arrays along an existing axis."""
dtype = _result_type(*arrays)
if axis is None:
arrays = [reshape(arr, (arr.size,)) for arr in arrays]
axis = 0
return jax.numpy.concatenate(arrays, axis=axis, dtype=dtype)
return jax.numpy.concat([arr.astype(dtype) for arr in arrays], axis=axis)


def expand_dims(x: Array, /, *, axis: int = 0) -> Array:
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
complex_ as complex_,
complexfloating as complexfloating,
compress as compress,
concat as concat,
concatenate as concatenate,
convolve as convolve,
copy as copy,
Expand Down
1 change: 1 addition & 0 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ complex_: Any
complexfloating = _np.complexfloating
def compress(condition: ArrayLike, a: ArrayLike, axis: Optional[int] = ...,
out: None = ...) -> Array: ...
def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: ...
def concatenate(
arrays: Union[_np.ndarray, Array, Sequence[ArrayLike]],
axis: Optional[int] = ...,
Expand Down
32 changes: 30 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,7 @@ def testCompressMethod(self, shape, dtype, axis):
@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
for axis in range(-len(base_shape)+1, len(base_shape))
for axis in (None, *range(-len(base_shape)+1, len(base_shape)))
],
arg_dtypes=[
arg_dtypes
Expand All @@ -1482,7 +1482,7 @@ def testCompressMethod(self, shape, dtype, axis):
)
def testConcatenate(self, axis, dtype, base_shape, arg_dtypes):
rng = jtu.rand_default(self.rng())
wrapped_axis = axis % len(base_shape)
wrapped_axis = 0 if axis is None else axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
for size, _ in zip(itertools.cycle([3, 1, 4]), arg_dtypes)]
@jtu.promote_like_jnp
Expand Down Expand Up @@ -1521,6 +1521,34 @@ def testConcatenateAxisNone(self):
b = jnp.array([[5]])
jnp.concatenate((a, b), axis=None)

def testConcatenateScalarAxisNone(self):
arrays = [np.int32(0), np.int32(1)]
self.assertArraysEqual(jnp.concatenate(arrays, axis=None),
np.concatenate(arrays, axis=None))

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(), (4,), (3, 4), (2, 3, 4)]
for axis in (None, *range(-len(base_shape)+1, len(base_shape)))
],
dtype=default_dtypes,
)
def testConcat(self, axis, base_shape, dtype):
rng = jtu.rand_default(self.rng())
wrapped_axis = 0 if axis is None else axis % len(base_shape)
shapes = [base_shape[:wrapped_axis] + (size,) + base_shape[wrapped_axis+1:]
for size in [3, 1, 4]]
@jtu.promote_like_jnp
def np_fun(*args):
if jtu.numpy_version() >= (2, 0, 0):
return np.concat(args, axis=axis)
else:
return np.concatenate(args, axis=axis)
jnp_fun = lambda *args: jnp.concat(args, axis=axis)
args_maker = lambda: [rng(shape, dtype) for shape in shapes]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(base_shape=base_shape, axis=axis)
for base_shape in [(4,), (3, 4), (2, 3, 4)]
Expand Down

0 comments on commit c83c0f3

Please sign in to comment.