From 55c632fd801df06176d0f3dbb4fe80f3a43afd85 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 13 Nov 2024 10:02:28 +0000 Subject: [PATCH] Unify chunks for `concat`, and check preconditions (#616) --- cubed/array_api/manipulation_functions.py | 40 ++++++++++++++++++++--- cubed/tests/test_array_api.py | 37 +++++++++++++++++++++ 2 files changed, 72 insertions(+), 5 deletions(-) diff --git a/cubed/array_api/manipulation_functions.py b/cubed/array_api/manipulation_functions.py index 284cda0e..4948d203 100644 --- a/cubed/array_api/manipulation_functions.py +++ b/cubed/array_api/manipulation_functions.py @@ -83,25 +83,55 @@ def concat(arrays, /, *, axis=0, chunks=None): if not arrays: raise ValueError("Need array(s) to concat") + if len({a.dtype for a in arrays}) > 1: + raise ValueError("concat inputs must all have the same dtype") + if axis is None: arrays = [flatten(array) for array in arrays] axis = 0 - # TODO: check arrays all have same shape (except in the dimension specified by axis) - # TODO: type promotion - # TODO: unify chunks + if len(arrays) == 1: + return arrays[0] a = arrays[0] + # check arrays all have same shape (except in the dimension specified by axis) + ndim = a.ndim + if not all( + i == axis or all(x.shape[i] == arrays[0].shape[i] for x in arrays) + for i in range(ndim) + ): + raise ValueError( + f"all the input array dimensions except for the concatenation axis must match exactly: {[x.shape for x in arrays]}" + ) + + # check arrays all have the same chunk size along axis (if more than one chunk) + if len({a.chunksize[axis] for a in arrays if a.numblocks[axis] > 1}) > 1: + raise ValueError( + f"all the input array chunk sizes must match along the concatenation axis: {[x.chunksize[axis] for x in arrays]}" + ) + + # unify chunks (except in the dimension specified by axis) + inds = [list(range(x.ndim)) for x in arrays] + for i, ind in enumerate(inds): + ind[axis] = -(i + 1) + uc_args = tlz.concat(zip(arrays, inds)) + chunkss, arrays = unify_chunks(*uc_args, warn=False) + # offsets along axis for the start of each array offsets = [0] + list(tlz.accumulate(add, [a.shape[axis] for a in arrays])) in_shapes = tuple(array.shape for array in arrays) - axis = validate_axis(axis, a.ndim) + axis = validate_axis(axis, ndim) shape = a.shape[:axis] + (offsets[-1],) + a.shape[axis + 1 :] dtype = a.dtype if chunks is None: - chunks = normalize_chunks(to_chunksize(a.chunks), shape=shape, dtype=dtype) + # use unified chunks except for dimension specified by axis + axis_chunksize = max(a.chunksize[axis] for a in arrays) + chunksize = tuple( + axis_chunksize if i == axis else chunkss[i] for i in range(ndim) + ) + chunks = normalize_chunks(chunksize, shape=shape, dtype=dtype) else: chunks = normalize_chunks(chunks, shape=shape, dtype=dtype) diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 42a245cb..4ea6653c 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -506,6 +506,43 @@ def test_concat(spec, executor): ) +def test_concat_different_chunks(spec): + a = xp.asarray([[1], [5]], chunks=(2, 2), spec=spec) + b = xp.asarray([[2, 3, 4], [6, 7, 8]], chunks=(2, 3), spec=spec) + c = xp.concat([a, b], axis=1) + assert_array_equal( + c.compute(), + np.concatenate( + [ + np.array([[1], [5]]), + np.array([[2, 3, 4], [6, 7, 8]]), + ], + axis=1, + ), + ) + + +@pytest.mark.parametrize("axis", [None, 0]) +def test_concat_single_array(spec, axis): + a = xp.full((4, 5), 1, chunks=(3, 2), spec=spec) + d = xp.concat([a], axis=axis) + assert_array_equal( + d.compute(), + np.concatenate([np.full((4, 5), 1)], axis=axis), + ) + + +def test_concat_incompatible_shapes(spec): + a = xp.full((4, 5), 1, chunks=(3, 2), spec=spec) + b = xp.full((4, 6), 2, chunks=(3, 2), spec=spec) + with pytest.raises( + ValueError, + match="all the input array dimensions except for the concatenation axis must match exactly", + ): + xp.concat([a, b], axis=0) + xp.concat([a, b], axis=1) # OK + + def test_expand_dims(spec, executor): a = xp.asarray([1, 2, 3], chunks=(2,), spec=spec) b = xp.expand_dims(a, axis=0)