Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Defer to merge_chunks in special cases of rechunk #612

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cubed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .core.array import compute, measure_reserved_mem, visualize
from .core.gufunc import apply_gufunc
from .core.ops import from_array, from_zarr, map_blocks, store, to_zarr
from .core.ops import from_array, from_zarr, map_blocks, rechunk, store, to_zarr
from .nan_functions import nanmean, nansum
from .overlap import map_overlap
from .pad import pad
Expand All @@ -38,6 +38,7 @@
"nanmean",
"nansum",
"pad",
"rechunk",
"store",
"to_zarr",
"visualize",
Expand Down
5 changes: 5 additions & 0 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,11 @@ def rechunk(x, chunks, target_store=None):
return x
# normalizing takes care of dict args for chunks
target_chunks = to_chunksize(normalized_chunks)

# merge chunks special case
if all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunks)):
return merge_chunks(x, target_chunks)

name = gensym()
spec = x.spec
if target_store is None:
Expand Down
26 changes: 26 additions & 0 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def test_rechunk(spec, executor, new_chunks, expected_chunks):
def test_rechunk_same_chunks(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 1), spec=spec)
b = a.rechunk((2, 1))
assert b is a
task_counter = TaskCounter()
res = b.compute(callbacks=[task_counter])
# no tasks should have run since chunks are same
Expand All @@ -314,6 +315,31 @@ def test_rechunk_intermediate(tmp_path):
assert_array_equal(b.compute(), np.ones((4, 4)))
intermediates = [n for (n, d) in b.plan.dag.nodes(data=True) if "-int" in d["name"]]
assert len(intermediates) == 1
rechunks = [
n
for (n, d) in b.plan.dag.nodes(data=True)
if d.get("op_name", None) == "rechunk"
]
assert len(rechunks) > 0


def test_rechunk_merge_chunks_optimization():
    """Rechunking (2, 1) -> (4, 2) must not add a "rechunk" op to the plan.

    Every target chunk length here is an exact multiple of the source chunk
    length, so the result should keep the same values, report the merged
    chunking, and contain no plan node whose ``op_name`` is ``"rechunk"``.
    """
    values = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]

    a = xp.asarray(values, chunks=(2, 1))
    b = a.rechunk((4, 2))

    # Chunking is updated and the data is unchanged.
    assert b.chunks == ((4,), (2, 2))
    assert_array_equal(b.compute(), np.array(values))

    # Collect any plan nodes that were produced by a real rechunk operation.
    rechunk_nodes = []
    for node, attrs in b.plan.dag.nodes(data=True):
        if attrs.get("op_name") == "rechunk":
            rechunk_nodes.append(node)
    assert not rechunk_nodes


def test_compute_is_idempotent(spec, executor):
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Chunk-specific functions
apply_gufunc
map_blocks
map_overlap
rechunk

Non-standardised functions
==========================
Expand Down
Loading