diff --git a/cubed/__init__.py b/cubed/__init__.py index 31c56240..048fea5b 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -15,7 +15,7 @@ from .core.array import compute, measure_reserved_mem, visualize from .core.gufunc import apply_gufunc -from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr +from .core.ops import from_array, from_zarr, map_blocks, rechunk, store, to_zarr from .nan_functions import nanmean, nansum from .overlap import map_overlap from .pad import pad @@ -38,6 +38,7 @@ "nanmean", "nansum", "pad", + "rechunk", "store", "to_zarr", "visualize", diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 23ea169b..cb8610f3 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -1024,6 +1024,11 @@ def rechunk(x, chunks, target_store=None): return x # normalizing takes care of dict args for chunks target_chunks = to_chunksize(normalized_chunks) + + # merge chunks special case + if all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunks)): + return merge_chunks(x, target_chunks) + name = gensym() spec = x.spec if target_store is None: diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index ba916517..efba6ffe 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -298,6 +298,7 @@ def test_rechunk(spec, executor, new_chunks, expected_chunks): def test_rechunk_same_chunks(spec): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec) b = a.rechunk((2, 1)) + assert b is a task_counter = TaskCounter() res = b.compute(callbacks=[task_counter]) # no tasks should have run since chunks are same @@ -314,6 +315,31 @@ def test_rechunk_intermediate(tmp_path): assert_array_equal(b.compute(), np.ones((4, 4))) intermediates = [n for (n, d) in b.plan.dag.nodes(data=True) if "-int" in d["name"]] assert len(intermediates) == 1 + rechunks = [ + n + for (n, d) in b.plan.dag.nodes(data=True) + if d.get("op_name", None) == "rechunk" + ] + assert len(rechunks) > 0 + + +def test_rechunk_merge_chunks_optimization(): + a = xp.asarray( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + chunks=(2, 1), + ) + b = a.rechunk((4, 2)) + assert b.chunks == ((4,), (2, 2)) + assert_array_equal( + b.compute(), + np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]), + ) + rechunks = [ + n + for (n, d) in b.plan.dag.nodes(data=True) + if d.get("op_name", None) == "rechunk" + ] + assert len(rechunks) == 0 def test_compute_is_idempotent(spec, executor): diff --git a/docs/api.rst b/docs/api.rst index c2bf32b9..6afacf4f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -46,6 +46,7 @@ Chunk-specific functions apply_gufunc map_blocks map_overlap + rechunk Non-standardised functions ==========================