Skip to content

Commit

Permalink
Unify chunks for concat, and check preconditions (#616)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Nov 13, 2024
1 parent 8883717 commit 55c632f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 5 deletions.
40 changes: 35 additions & 5 deletions cubed/array_api/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,25 +83,55 @@ def concat(arrays, /, *, axis=0, chunks=None):
if not arrays:
raise ValueError("Need array(s) to concat")

if len({a.dtype for a in arrays}) > 1:
raise ValueError("concat inputs must all have the same dtype")

if axis is None:
arrays = [flatten(array) for array in arrays]
axis = 0

# TODO: check arrays all have same shape (except in the dimension specified by axis)
# TODO: type promotion
# TODO: unify chunks
if len(arrays) == 1:
return arrays[0]

a = arrays[0]

# check arrays all have same shape (except in the dimension specified by axis)
ndim = a.ndim
if not all(
i == axis or all(x.shape[i] == arrays[0].shape[i] for x in arrays)
for i in range(ndim)
):
raise ValueError(
f"all the input array dimensions except for the concatenation axis must match exactly: {[x.shape for x in arrays]}"
)

# check arrays all have the same chunk size along axis (if more than one chunk)
if len({a.chunksize[axis] for a in arrays if a.numblocks[axis] > 1}) > 1:
raise ValueError(
f"all the input array chunk sizes must match along the concatenation axis: {[x.chunksize[axis] for x in arrays]}"
)

# unify chunks (except in the dimension specified by axis)
inds = [list(range(x.ndim)) for x in arrays]
for i, ind in enumerate(inds):
ind[axis] = -(i + 1)
uc_args = tlz.concat(zip(arrays, inds))
chunkss, arrays = unify_chunks(*uc_args, warn=False)

# offsets along axis for the start of each array
offsets = [0] + list(tlz.accumulate(add, [a.shape[axis] for a in arrays]))
in_shapes = tuple(array.shape for array in arrays)

axis = validate_axis(axis, a.ndim)
axis = validate_axis(axis, ndim)
shape = a.shape[:axis] + (offsets[-1],) + a.shape[axis + 1 :]
dtype = a.dtype
if chunks is None:
chunks = normalize_chunks(to_chunksize(a.chunks), shape=shape, dtype=dtype)
# use unified chunks except for dimension specified by axis
axis_chunksize = max(a.chunksize[axis] for a in arrays)
chunksize = tuple(
axis_chunksize if i == axis else chunkss[i] for i in range(ndim)
)
chunks = normalize_chunks(chunksize, shape=shape, dtype=dtype)
else:
chunks = normalize_chunks(chunks, shape=shape, dtype=dtype)

Expand Down
37 changes: 37 additions & 0 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,43 @@ def test_concat(spec, executor):
)


def test_concat_different_chunks(spec):
    """Concatenate two arrays along axis 1 whose chunk specifications differ.

    The inputs are created with ``chunks=(2, 2)`` and ``chunks=(2, 3)``
    respectively; the computed result must equal what NumPy produces for the
    same data.
    """
    left_values = [[1], [5]]
    right_values = [[2, 3, 4], [6, 7, 8]]

    left = xp.asarray(left_values, chunks=(2, 2), spec=spec)
    right = xp.asarray(right_values, chunks=(2, 3), spec=spec)
    joined = xp.concat([left, right], axis=1)

    # Reference result built from the same literals with plain NumPy.
    expected = np.concatenate(
        [np.array(left_values), np.array(right_values)],
        axis=1,
    )
    assert_array_equal(joined.compute(), expected)


@pytest.mark.parametrize("axis", [None, 0])
def test_concat_single_array(spec, axis):
    """Concatenating a one-element list matches NumPy for each parametrized axis."""
    shape = (4, 5)
    fill_value = 1

    only = xp.full(shape, fill_value, chunks=(3, 2), spec=spec)
    result = xp.concat([only], axis=axis)

    # NumPy is the reference for both the axis=None and the axis=0 case.
    expected = np.concatenate([np.full(shape, fill_value)], axis=axis)
    assert_array_equal(result.compute(), expected)


def test_concat_incompatible_shapes(spec):
    """Shapes that differ off the concatenation axis must raise ``ValueError``."""
    chunking = (3, 2)
    five_wide = xp.full((4, 5), 1, chunks=chunking, spec=spec)
    six_wide = xp.full((4, 6), 2, chunks=chunking, spec=spec)
    expected_message = "all the input array dimensions except for the concatenation axis must match exactly"

    # Along axis 0 the second dimension (5 vs 6) has to agree, so this fails.
    with pytest.raises(ValueError, match=expected_message):
        xp.concat([five_wide, six_wide], axis=0)

    # Along axis 1 only the first dimension (4 vs 4) has to agree.
    xp.concat([five_wide, six_wide], axis=1)  # OK


def test_expand_dims(spec, executor):
a = xp.asarray([1, 2, 3], chunks=(2,), spec=spec)
b = xp.expand_dims(a, axis=0)
Expand Down

0 comments on commit 55c632f

Please sign in to comment.