From 05583bdfa2c3f3af119381c3114a75a4d47e81eb Mon Sep 17 00:00:00 2001 From: Tom White Date: Sat, 13 Jan 2024 15:42:35 +0000 Subject: [PATCH 1/8] Add num_input_blocks to primitive operation for bookkeeping when fusing operations with different numbers of tasks --- cubed/core/ops.py | 6 ++++++ cubed/core/plan.py | 4 ++++ cubed/primitive/blockwise.py | 15 +++++++++++++++ cubed/primitive/types.py | 5 ++++- 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index bb71c1b2..274b8ef9 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -266,6 +266,7 @@ def blockwise( extra_projected_mem = kwargs.pop("extra_projected_mem", 0) fusable = kwargs.pop("fusable", True) + num_input_blocks = kwargs.pop("num_input_blocks", (1,) * len(source_arrays)) name = gensym() spec = check_array_specs(arrays) @@ -287,6 +288,7 @@ def blockwise( out_name=name, extra_func_kwargs=extra_func_kwargs, fusable=fusable, + num_input_blocks=num_input_blocks, **kwargs, ) plan = Plan._new( @@ -324,6 +326,8 @@ def general_blockwise( extra_projected_mem = kwargs.pop("extra_projected_mem", 0) + num_input_blocks = kwargs.pop("num_input_blocks", (1,) * len(source_arrays)) + name = gensym() spec = check_array_specs(arrays) if target_store is None: @@ -341,6 +345,7 @@ def general_blockwise( chunks=chunks, in_names=in_names, extra_func_kwargs=extra_func_kwargs, + num_input_blocks=num_input_blocks, **kwargs, ) plan = Plan._new( @@ -1059,6 +1064,7 @@ def block_function(out_key): dtype=dtype, chunks=chunks, extra_projected_mem=extra_projected_mem, + num_input_blocks=(sum(split_every.values()),), reduce_func=func, initial_func=initial_func, axis=axis, diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 9b8deb28..4ea0b1bf 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -326,6 +326,10 @@ def visualize( tooltip += f"\ntasks: {primitive_op.num_tasks}" if primitive_op.write_chunks is not None: tooltip += f"\nwrite chunks: {primitive_op.write_chunks}" + if primitive_op.num_input_blocks is not None: + tooltip += ( + f"\nnum input blocks: {primitive_op.num_input_blocks}" + ) del d["primitive_op"] # remove pipeline attribute since it is a long string that causes graphviz to fail diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 5ea1a623..3846c18f 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -119,6 +119,7 @@ def blockwise( extra_projected_mem: int = 0, extra_func_kwargs: Optional[Dict[str, Any]] = None, fusable: bool = True, + num_input_blocks: Optional[Tuple[int, ...]] = None, **kwargs, ): """Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules. 
@@ -201,6 +202,7 @@ def blockwise( extra_projected_mem=extra_projected_mem, extra_func_kwargs=extra_func_kwargs, fusable=fusable, + num_input_blocks=num_input_blocks, **kwargs, ) @@ -219,6 +221,7 @@ def general_blockwise( extra_projected_mem: int = 0, extra_func_kwargs: Optional[Dict[str, Any]] = None, fusable: bool = True, + num_input_blocks: Optional[Tuple[int, ...]] = None, **kwargs, ): """A more general form of ``blockwise`` that uses a function to specify the block @@ -317,6 +320,7 @@ def general_blockwise( reserved_mem=reserved_mem, num_tasks=num_tasks, fusable=fusable, + num_input_blocks=num_input_blocks, ) @@ -399,6 +403,9 @@ def fused_func(*args): allowed_mem = primitive_op2.allowed_mem reserved_mem = primitive_op2.reserved_mem num_tasks = primitive_op2.num_tasks + num_input_blocks = tuple( + n * primitive_op2.num_input_blocks[0] for n in primitive_op1.num_input_blocks + ) pipeline = CubedPipeline( apply_blockwise, @@ -414,6 +421,7 @@ def fused_func(*args): reserved_mem=reserved_mem, num_tasks=num_tasks, fusable=True, + num_input_blocks=num_input_blocks, ) @@ -490,6 +498,12 @@ def fused_func(*args): allowed_mem = primitive_op.allowed_mem reserved_mem = primitive_op.reserved_mem num_tasks = primitive_op.num_tasks + tmp = [ + p.num_input_blocks if p is not None else (1,) for p in predecessor_primitive_ops + ] + num_input_blocks = tuple( + primitive_op.num_input_blocks[0] * n for n in itertools.chain(*tmp) + ) fused_pipeline = CubedPipeline( apply_blockwise, @@ -505,6 +519,7 @@ def fused_func(*args): reserved_mem=reserved_mem, num_tasks=num_tasks, fusable=True, + num_input_blocks=num_input_blocks, ) diff --git a/cubed/primitive/types.py b/cubed/primitive/types.py index c9a3316c..ecf96be3 100644 --- a/cubed/primitive/types.py +++ b/cubed/primitive/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Optional, Tuple import zarr @@ -37,6 +37,9 @@ class PrimitiveOperation: fusable: bool = True """Whether this operation should be considered for fusion.""" + num_input_blocks: Optional[Tuple[int, ...]] = None + """The number of input blocks read from each input array.""" + write_chunks: Optional[T_RegularChunks] = None """The chunk size used by this operation.""" From 726bcf8a0ea72ffe97545b8983baa5d71041d69d Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 23 Jan 2024 12:15:55 +0000 Subject: [PATCH 2/8] Implement merge_chunks_new using general_blockwise --- cubed/core/ops.py | 61 ++++++++++++++++++++++++++++++++++++++++++-- cubed/utils.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/cubed/core/ops.py b/cubed/core/ops.py index 274b8ef9..d8bddc5b 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -27,10 +27,16 @@ from cubed.primitive.blockwise import blockwise as primitive_blockwise from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise from cubed.primitive.rechunk import rechunk as primitive_rechunk -from cubed.utils import chunk_memory, get_item, offset_to_block_id, to_chunksize +from cubed.utils import ( + _concatenate2, + chunk_memory, + get_item, + offset_to_block_id, + to_chunksize, +) from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks from cubed.vendor.dask.array.utils import validate_axis -from cubed.vendor.dask.blockwise import broadcast_dimensions +from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product from cubed.vendor.dask.utils import has_keyword if 
TYPE_CHECKING: @@ -764,6 +770,7 @@ def rechunk(x, chunks, target_store=None): def merge_chunks(x, chunks): + """Merge multiple chunks into one.""" target_chunksize = chunks if len(target_chunksize) != x.ndim: raise ValueError( @@ -792,6 +799,56 @@ def _copy_chunk(e, x, target_chunks=None, block_id=None): return out +def merge_chunks_new(x, chunks): + # new implementation that uses general_blockwise rather than map_direct + target_chunksize = chunks + if len(target_chunksize) != x.ndim: + raise ValueError( + f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})" + ) + if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)): + raise ValueError( + f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}" + ) + + target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype) + axes = [ + i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1 + ] + + def block_function(out_key): + out_coords = out_key[1:] + + in_keys = [] + for i, (c0, c1) in enumerate(zip(x.chunksize, target_chunksize)): + k = c1 // c0 # number of blocks to merge in axis i + if k == 1: + in_keys.append(out_coords[i]) + else: + start = out_coords[i] * k + stop = min(start + k, x.numblocks[i]) + in_keys.append(list(range(start, stop))) + + # return a tuple with a single item that is the list of input keys to be merged + return (lol_product((x.name,), in_keys),) + + num_input_blocks = int( + np.prod([c1 // c0 for (c0, c1) in zip(x.chunksize, target_chunksize)]) + ) + + return general_blockwise( + _concatenate2, + block_function, + x, + shape=x.shape, + dtype=x.dtype, + chunks=target_chunks, + extra_projected_mem=0, + num_input_blocks=(num_input_blocks,), + axes=axes, + ) + + def reduction( x: "Array", func, diff --git a/cubed/utils.py b/cubed/utils.py index ff36192c..84b1cfc3 100644 --- a/cubed/utils.py +++ b/cubed/utils.py @@ -310,3 +310,68 @@ def broadcast_trick(func): inner.__doc__ = func.__doc__ inner.__name__ = func.__name__ return inner + + +# From dask.array.core, but changed to use nxp namespace +def _concatenate2(arrays, axes=None): + """Recursively concatenate nested lists of arrays along axes + + Each entry in axes corresponds to each level of the nested list. The + length of axes should correspond to the level of nesting of arrays. + If axes is an empty list or tuple, return arrays, or arrays[0] if + arrays is a list. + + >>> x = np.array([[1, 2], [3, 4]]) + >>> _concatenate2([x, x], axes=[0]) + array([[1, 2], + [3, 4], + [1, 2], + [3, 4]]) + + >>> _concatenate2([x, x], axes=[1]) + array([[1, 2, 1, 2], + [3, 4, 3, 4]]) + + >>> _concatenate2([[x, x], [x, x]], axes=[0, 1]) + array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]]) + + Supports Iterators + >>> _concatenate2(iter([x, x]), axes=[1]) + array([[1, 2, 1, 2], + [3, 4, 3, 4]]) + + Special Case + >>> _concatenate2([x, x], axes=()) + array([[1, 2], + [3, 4]]) + """ + if axes is None: + axes = [] + + if axes == (): + if isinstance(arrays, list): + return arrays[0] + else: + return arrays + + if isinstance(arrays, Iterator): + arrays = list(arrays) + if not isinstance(arrays, (list, tuple)): + return arrays + if len(axes) > 1: + arrays = [_concatenate2(a, axes=axes[1:]) for a in arrays] + concatenate = nxp.concat + if isinstance(arrays[0], dict): + # Handle concatenation of `dict`s, used as a replacement for structured + # arrays when that's not supported by the array library (e.g., CuPy). 
+ keys = list(arrays[0].keys()) + assert all(list(a.keys()) == keys for a in arrays) + ret = dict() + for k in keys: + ret[k] = concatenate(list(a[k] for a in arrays), axis=axes[0]) + return ret + else: + return concatenate(arrays, axis=axes[0]) From 5b861ff859a673a6b630e6848495a2cd12540940 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 22 Jan 2024 14:25:35 +0000 Subject: [PATCH 3/8] Add failing test for merge_chunks_new --- cubed/tests/test_optimization.py | 105 ++++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 3 deletions(-) diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index d13a8c35..f414becf 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -8,7 +8,7 @@ import cubed import cubed.array_api as xp from cubed.backend_array_api import namespace as nxp -from cubed.core.ops import elemwise +from cubed.core.ops import elemwise, merge_chunks_new from cubed.core.optimization import ( fuse_all_optimize_dag, fuse_only_optimize_dag, @@ -138,9 +138,13 @@ def custom_optimize_function(dag): ) -def fuse_one_level(arr): +def fuse_one_level(arr, *, always_fuse=None): # use fuse_predecessors to test one level of fusion - return partial(fuse_predecessors, name=next(arr.plan.dag.predecessors(arr.name))) + return partial( + fuse_predecessors, + name=next(arr.plan.dag.predecessors(arr.name)), + always_fuse=always_fuse, + ) def fuse_multiple_levels(*, max_total_source_arrays=4): @@ -681,6 +685,101 @@ def test_fuse_large_fan_in_override(spec): assert_array_equal(result, 8 * np.ones((2, 2))) +# merge chunks with same number of tasks +# +# a -> a +# | | +# b c +# | +# c +# +def test_fuse_with_merge_chunks_unary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = merge_chunks_new(a, chunks=(3, 2)) + c = xp.negative(b) + + opt_fn = fuse_one_level(c) + + c.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a,), (c,)) + assert structurally_equivalent( + c.plan.optimize(optimize_function=opt_fn).dag, + expected_fused_dag, + ) + + result = c.compute(optimize_function=opt_fn) + assert_array_equal(result, -np.ones((3, 2))) + + +# merge chunks with different number of tasks (b has more tasks than c) +# +# a -> a +# | | +# b c +# | +# c +# +def test_fuse_merge_chunks_unary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = xp.negative(a) + c = merge_chunks_new(b, chunks=(3, 2)) + + # force c to fuse + last_op = sorted(c.plan.dag.nodes())[-1] + opt_fn = fuse_one_level(c, always_fuse=[last_op]) + + c.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (a,), (c,)) + assert structurally_equivalent( + c.plan.optimize(optimize_function=opt_fn).dag, + expected_fused_dag, + ) + + result = c.compute(optimize_function=opt_fn) + assert_array_equal(result, -np.ones((3, 2))) + + +# merge chunks with different number of tasks (c has more tasks than d) +# +# a b -> a b +# \ / \ / +# c d +# | +# d +# +def test_fuse_merge_chunks_binary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = xp.ones((3, 2), chunks=(1, 2), spec=spec) + c = xp.add(a, b) + d = merge_chunks_new(c, chunks=(3, 2)) + + # force d to fuse + last_op = sorted(d.plan.dag.nodes())[-1] + opt_fn = fuse_one_level(d, always_fuse=[last_op]) + + 
d.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + add_placeholder_op(expected_fused_dag, (a, b), (d,)) + assert structurally_equivalent( + d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag + ) + + result = d.compute(optimize_function=opt_fn) + assert_array_equal(result, 2 * np.ones((3, 2))) + + def test_fuse_only_optimize_dag(spec): a = xp.ones((2, 2), chunks=(2, 2), spec=spec) b = xp.negative(a) From de9d176c45a33c27185cda8268f9f410acfafb34 Mon Sep 17 00:00:00 2001 From: Tom White Date: Thu, 18 Jan 2024 16:35:45 +0000 Subject: [PATCH 4/8] Don't assert that num tasks must be the same in fuse_multiple --- cubed/primitive/blockwise.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 3846c18f..b8b94598 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -432,12 +432,6 @@ def fuse_multiple( Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations. """ - assert all( - primitive_op.num_tasks == p.num_tasks - for p in predecessor_primitive_ops - if p is not None - ) - pipeline = primitive_op.pipeline predecessor_pipelines = [ primitive_op.pipeline if primitive_op is not None else None From ddd340c00b5b1c3847b6f683a154370baac4b13e Mon Sep 17 00:00:00 2001 From: Tom White Date: Sat, 13 Jan 2024 15:30:32 +0000 Subject: [PATCH 5/8] Fuse primitive ops with different numbers of tasks --- cubed/primitive/blockwise.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index b8b94598..65c10f2c 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -1,5 +1,6 @@ import itertools import math +from collections.abc import Iterator from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -446,10 +447,19 @@ def fuse_multiple( mappable = pipeline.mappable - def apply_pipeline_block_func(pipeline, arg): + def apply_pipeline_block_func(pipeline, n_input_blocks, arg): if pipeline is None: return (arg,) - return pipeline.config.block_function(arg) + if n_input_blocks == 1: + assert isinstance(arg, tuple) + return pipeline.config.block_function(arg) + else: + # more than one input block is being read from arg + assert isinstance(arg, (list, Iterator)) + return tuple( + list(item) + for item in zip(*(pipeline.config.block_function(a) for a in arg)) + ) def fused_blockwise_func(out_key): # this will change when multiple outputs are supported @@ -457,20 +467,28 @@ def fused_blockwise_func(out_key): # split all args to the fused function into groups, one for each predecessor function func_args = tuple( item - for p, a in zip(predecessor_pipelines, args) - for item in apply_pipeline_block_func(p, a) + for i, (p, a) in enumerate(zip(predecessor_pipelines, args)) + for item in apply_pipeline_block_func( + p, primitive_op.num_input_blocks[i], a + ) ) return split_into(func_args, predecessor_funcs_nargs) - def apply_pipeline_func(pipeline, *args): + def apply_pipeline_func(pipeline, n_input_blocks, *args): if pipeline is None: return args[0] - return pipeline.config.function(*args) + if n_input_blocks == 1: + ret = 
pipeline.config.function(*args) + else: + # more than one input block is being read from this group of args to primitive op + ret = [pipeline.config.function(*item) for item in list(zip(*args))] + return ret def fused_func(*args): # args are grouped appropriately so they can be called by each predecessor function func_args = [ - apply_pipeline_func(p, *a) for p, a in zip(predecessor_pipelines, args) + apply_pipeline_func(p, primitive_op.num_input_blocks[i], *a) + for i, (p, a) in enumerate(zip(predecessor_pipelines, args)) ] return pipeline.config.function(*func_args) From 9a7c3ed9e8ee726e86806bfd0a4099775fcb2a65 Mon Sep 17 00:00:00 2001 From: Tom White Date: Mon, 29 Jan 2024 16:20:01 +0000 Subject: [PATCH 6/8] Improve quad means test --- cubed/array_api/statistical_functions.py | 3 ++- cubed/tests/test_core.py | 15 ++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/cubed/array_api/statistical_functions.py b/cubed/array_api/statistical_functions.py index 53022cb2..21a50d2c 100644 --- a/cubed/array_api/statistical_functions.py +++ b/cubed/array_api/statistical_functions.py @@ -25,7 +25,7 @@ def max(x, /, *, axis=None, keepdims=False): return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims) -def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False): +def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False, split_every=None): if x.dtype not in _real_floating_dtypes: raise TypeError("Only real floating-point dtypes are allowed in mean") # This implementation uses NumPy and Zarr's structured arrays to store a @@ -47,6 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False): dtype=dtype, keepdims=keepdims, use_new_impl=use_new_impl, + split_every=split_every, extra_func_kwargs=extra_func_kwargs, ) diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 94071ad5..ea62fdaf 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -12,7 +12,7 @@ import cubed.random from cubed.backend_array_api import namespace as nxp from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce -from cubed.core.optimization import fuse_all_optimize_dag +from cubed.core.optimization import fuse_all_optimize_dag, multiple_inputs_optimize_dag from cubed.tests.utils import ( ALL_EXECUTORS, MAIN_EXECUTORS, @@ -531,10 +531,19 @@ def test_plan_quad_means(tmp_path, t_length): u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec) uv = u * v - m = xp.mean(uv, axis=0) + m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True) assert m.plan.num_tasks() > 0 - m.visualize(filename=tmp_path / "quad_means") + m.visualize( + filename=tmp_path / "quad_means_unoptimized", + optimize_graph=False, + show_hidden=True, + ) + m.visualize( + filename=tmp_path / "quad_means", + optimize_function=multiple_inputs_optimize_dag, + show_hidden=True, + ) def quad_means(tmp_path, t_length): From 263d89c7cf510dfd5fa0722dd750e6c6c5da2c2f Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 31 Jan 2024 09:57:12 +0000 Subject: [PATCH 7/8] Move 'num_input_blocks' to blockwise spec Add tests to check num_input_blocks --- cubed/core/ops.py | 4 +- cubed/core/plan.py | 9 ++-- cubed/primitive/blockwise.py | 62 ++++++++++++++++++++-------- cubed/primitive/types.py | 5 +-- cubed/tests/test_optimization.py | 70 ++++++++++++++++++++++++++++++++ 5 files changed, 122 insertions(+), 28 deletions(-) diff --git a/cubed/core/ops.py 
b/cubed/core/ops.py index d8bddc5b..41db51dd 100644 --- a/cubed/core/ops.py +++ b/cubed/core/ops.py @@ -272,7 +272,7 @@ def blockwise( extra_projected_mem = kwargs.pop("extra_projected_mem", 0) fusable = kwargs.pop("fusable", True) - num_input_blocks = kwargs.pop("num_input_blocks", (1,) * len(source_arrays)) + num_input_blocks = kwargs.pop("num_input_blocks", None) name = gensym() spec = check_array_specs(arrays) @@ -332,7 +332,7 @@ def general_blockwise( extra_projected_mem = kwargs.pop("extra_projected_mem", 0) - num_input_blocks = kwargs.pop("num_input_blocks", (1,) * len(source_arrays)) + num_input_blocks = kwargs.pop("num_input_blocks", None) name = gensym() spec = check_array_specs(arrays) diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 4ea0b1bf..12b7ad56 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -326,14 +326,15 @@ def visualize( tooltip += f"\ntasks: {primitive_op.num_tasks}" if primitive_op.write_chunks is not None: tooltip += f"\nwrite chunks: {primitive_op.write_chunks}" - if primitive_op.num_input_blocks is not None: - tooltip += ( - f"\nnum input blocks: {primitive_op.num_input_blocks}" - ) del d["primitive_op"] # remove pipeline attribute since it is a long string that causes graphviz to fail if "pipeline" in d: + pipeline = d["pipeline"] + if pipeline.config is not None: + tooltip += ( + f"\nnum input blocks: {pipeline.config.num_input_blocks}" + ) del d["pipeline"] if "stack_summaries" in d and d["stack_summaries"] is not None: diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index 65c10f2c..efa3e7d9 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -46,6 +46,8 @@ class BlockwiseSpec: A function that maps input chunks to an output chunk. function_nargs: int The number of array arguments that ``function`` takes. + num_input_blocks: Tuple[int, ...] + The number of input blocks read from each input array. reads_map : Dict[str, CubedArrayProxy] Read proxy dictionary keyed by array name. write : CubedArrayProxy @@ -55,6 +57,7 @@ class BlockwiseSpec: block_function: Callable[..., Any] function: Callable[..., Any] function_nargs: int + num_input_blocks: Tuple[int, ...] 
reads_map: Dict[str, CubedArrayProxy] write: CubedArrayProxy @@ -275,12 +278,18 @@ def general_blockwise( func_kwargs = extra_func_kwargs or {} func_with_kwargs = partial(func, **{**kwargs, **func_kwargs}) + num_input_blocks = num_input_blocks or (1,) * len(arrays) read_proxies = { name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items() } write_proxy = CubedArrayProxy(target_array, chunksize) spec = BlockwiseSpec( - block_function, func_with_kwargs, len(arrays), read_proxies, write_proxy + block_function, + func_with_kwargs, + len(arrays), + num_input_blocks, + read_proxies, + write_proxy, ) # calculate projected memory @@ -321,7 +330,6 @@ def general_blockwise( reserved_mem=reserved_mem, num_tasks=num_tasks, fusable=fusable, - num_input_blocks=num_input_blocks, ) @@ -353,6 +361,11 @@ def can_fuse_multiple_primitive_ops( # larger than allowed_mem then we can't fuse if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem: return False + # if the number of input blocks for each input is not uniform, then we can't fuse + # (this should never happen since all operations are currently uniform) + num_input_blocks = primitive_op.pipeline.config.num_input_blocks + if not all(num_input_blocks[0] == n for n in num_input_blocks): + return False return all( primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops ) @@ -395,8 +408,17 @@ def fused_func(*args): function_nargs = pipeline1.config.function_nargs read_proxies = pipeline1.config.reads_map write_proxy = pipeline2.config.write + num_input_blocks = tuple( + n * pipeline2.config.num_input_blocks[0] + for n in pipeline1.config.num_input_blocks + ) spec = BlockwiseSpec( - fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy + fused_blockwise_func, + fused_func, + function_nargs, + num_input_blocks, + read_proxies, + write_proxy, ) target_array = primitive_op2.target_array @@ -404,9 +426,6 @@ def fused_func(*args): allowed_mem = primitive_op2.allowed_mem reserved_mem = primitive_op2.reserved_mem num_tasks = primitive_op2.num_tasks - num_input_blocks = tuple( - n * primitive_op2.num_input_blocks[0] for n in primitive_op1.num_input_blocks - ) pipeline = CubedPipeline( apply_blockwise, @@ -422,7 +441,6 @@ def fused_func(*args): reserved_mem=reserved_mem, num_tasks=num_tasks, fusable=True, - num_input_blocks=num_input_blocks, ) @@ -469,7 +487,7 @@ def fused_blockwise_func(out_key): item for i, (p, a) in enumerate(zip(predecessor_pipelines, args)) for item in apply_pipeline_block_func( - p, primitive_op.num_input_blocks[i], a + p, pipeline.config.num_input_blocks[i], a ) ) return split_into(func_args, predecessor_funcs_nargs) @@ -487,19 +505,34 @@ def apply_pipeline_func(pipeline, n_input_blocks, *args): def fused_func(*args): # args are grouped appropriately so they can be called by each predecessor function func_args = [ - apply_pipeline_func(p, primitive_op.num_input_blocks[i], *a) + apply_pipeline_func(p, pipeline.config.num_input_blocks[i], *a) for i, (p, a) in enumerate(zip(predecessor_pipelines, args)) ] return pipeline.config.function(*func_args) - function_nargs = pipeline.config.function_nargs + fused_function_nargs = pipeline.config.function_nargs + # ok to get num_input_blocks[0] since it is uniform (see check in can_fuse_multiple_primitive_ops) + fused_num_input_blocks = tuple( + pipeline.config.num_input_blocks[0] * n + for n in itertools.chain( + *( + p.pipeline.config.num_input_blocks if p is not None else (1,) + for p in predecessor_primitive_ops + ) + ) + 
) read_proxies = dict(pipeline.config.reads_map) for p in predecessor_pipelines: if p is not None: read_proxies.update(p.config.reads_map) write_proxy = pipeline.config.write spec = BlockwiseSpec( - fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy + fused_blockwise_func, + fused_func, + fused_function_nargs, + fused_num_input_blocks, + read_proxies, + write_proxy, ) target_array = primitive_op.target_array @@ -510,12 +543,6 @@ def fused_func(*args): allowed_mem = primitive_op.allowed_mem reserved_mem = primitive_op.reserved_mem num_tasks = primitive_op.num_tasks - tmp = [ - p.num_input_blocks if p is not None else (1,) for p in predecessor_primitive_ops - ] - num_input_blocks = tuple( - primitive_op.num_input_blocks[0] * n for n in itertools.chain(*tmp) - ) fused_pipeline = CubedPipeline( apply_blockwise, @@ -531,7 +558,6 @@ def fused_func(*args): reserved_mem=reserved_mem, num_tasks=num_tasks, fusable=True, - num_input_blocks=num_input_blocks, ) diff --git a/cubed/primitive/types.py b/cubed/primitive/types.py index ecf96be3..c9a3316c 100644 --- a/cubed/primitive/types.py +++ b/cubed/primitive/types.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any, Optional import zarr @@ -37,9 +37,6 @@ class PrimitiveOperation: fusable: bool = True """Whether this operation should be considered for fusion.""" - num_input_blocks: Optional[Tuple[int, ...]] = None - """The number of input blocks read from each input array.""" - write_chunks: Optional[T_RegularChunks] = None """The chunk size used by this operation.""" diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index f414becf..a9096b14 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -138,6 +138,11 @@ def custom_optimize_function(dag): ) +def get_num_input_blocks(dag, arr_name): + op_name = next(dag.predecessors(arr_name)) + return dag.nodes(data=True)[op_name]["pipeline"].config.num_input_blocks + + def fuse_one_level(arr, *, always_fuse=None): # use fuse_predecessors to test one level of fusion return partial( @@ -232,6 +237,10 @@ def test_fuse_unary_op(spec): c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag, ) + assert get_num_input_blocks(c.plan.dag, c.name) == (1,) + assert get_num_input_blocks( + c.plan.optimize(optimize_function=opt_fn).dag, c.name + ) == (1,) num_created_arrays = 2 # b, c assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2 @@ -272,6 +281,10 @@ def test_fuse_binary_op(spec): assert structurally_equivalent( e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) + assert get_num_input_blocks( + e.plan.optimize(optimize_function=opt_fn).dag, e.name + ) == (1, 1) num_created_arrays = 3 # c, d, e assert e.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3 @@ -314,6 +327,10 @@ def test_fuse_unary_and_binary_op(spec): assert structurally_equivalent( f.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(f.plan.dag, f.name) == (1, 1) + assert get_num_input_blocks( + f.plan.optimize(optimize_function=opt_fn).dag, f.name + ) == (1, 1, 1) result = f.compute(optimize_function=opt_fn) assert_array_equal(result, np.ones((2, 2))) @@ -347,6 +364,10 @@ def test_fuse_mixed_levels(spec): assert structurally_equivalent( e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert 
get_num_input_blocks(e.plan.dag, e.name) == (1, 1) + assert get_num_input_blocks( + e.plan.optimize(optimize_function=opt_fn).dag, e.name + ) == (1, 1, 1) result = e.compute(optimize_function=opt_fn) assert_array_equal(result, 3 * np.ones((2, 2))) @@ -377,6 +398,10 @@ def test_fuse_diamond(spec): assert structurally_equivalent( d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) + assert get_num_input_blocks( + d.plan.optimize(optimize_function=opt_fn).dag, d.name + ) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -411,6 +436,10 @@ def test_fuse_mixed_levels_and_diamond(spec): assert structurally_equivalent( d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) + assert get_num_input_blocks( + d.plan.optimize(optimize_function=opt_fn).dag, d.name + ) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -441,6 +470,10 @@ def test_fuse_repeated_argument(spec): assert structurally_equivalent( c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) + assert get_num_input_blocks( + c.plan.optimize(optimize_function=opt_fn).dag, c.name + ) == (1, 1) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -2 * np.ones((2, 2))) @@ -476,6 +509,10 @@ def test_fuse_other_dependents(spec): assert structurally_equivalent( plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(c.plan.dag, c.name) == (1,) + assert get_num_input_blocks( + c.plan.optimize(optimize_function=opt_fn).dag, c.name + ) == (1,) c_result, d_result = cubed.compute(c, d, optimize_function=opt_fn) assert_array_equal(c_result, np.ones((2, 2))) @@ -542,6 +579,11 @@ def stack_add(*a): assert structurally_equivalent( j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(j.plan.dag, j.name) == (1,) + assert ( + get_num_input_blocks(j.plan.optimize(optimize_function=opt_fn).dag, j.name) + == (1,) * 8 + ) result = j.compute(optimize_function=opt_fn) assert_array_equal(result, -8 * np.ones((2, 2))) @@ -600,6 +642,10 @@ def test_fuse_large_fan_in_default(spec): assert structurally_equivalent( p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) + assert get_num_input_blocks( + p.plan.optimize(optimize_function=opt_fn).dag, p.name + ) == (1, 1) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -669,6 +715,11 @@ def test_fuse_large_fan_in_override(spec): assert structurally_equivalent( p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) + assert ( + get_num_input_blocks(p.plan.optimize(optimize_function=opt_fn).dag, p.name) + == (1,) * 8 + ) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -710,6 +761,11 @@ def test_fuse_with_merge_chunks_unary(spec): c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag, ) + assert get_num_input_blocks(b.plan.dag, b.name) == (3,) + assert get_num_input_blocks(c.plan.dag, c.name) == (1,) + assert get_num_input_blocks( + c.plan.optimize(optimize_function=opt_fn).dag, c.name + ) == (3,) result = c.compute(optimize_function=opt_fn) 
assert_array_equal(result, -np.ones((3, 2))) @@ -742,6 +798,11 @@ def test_fuse_merge_chunks_unary(spec): c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag, ) + assert get_num_input_blocks(b.plan.dag, b.name) == (1,) + assert get_num_input_blocks(c.plan.dag, c.name) == (3,) + assert get_num_input_blocks( + c.plan.optimize(optimize_function=opt_fn).dag, c.name + ) == (3,) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((3, 2))) @@ -775,6 +836,11 @@ def test_fuse_merge_chunks_binary(spec): assert structurally_equivalent( d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag ) + assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) + assert get_num_input_blocks(d.plan.dag, d.name) == (3,) + assert get_num_input_blocks( + d.plan.optimize(optimize_function=opt_fn).dag, d.name + ) == (3, 3) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((3, 2))) @@ -802,6 +868,10 @@ def test_fuse_only_optimize_dag(spec): d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag, ) + assert get_num_input_blocks(d.plan.dag, d.name) == (1,) + assert get_num_input_blocks( + d.plan.optimize(optimize_function=opt_fn).dag, d.name + ) == (1,) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((2, 2))) From e54aa1ba60218b2447220dcdc0621f8c03025b30 Mon Sep 17 00:00:00 2001 From: Tom White Date: Wed, 31 Jan 2024 17:43:10 +0000 Subject: [PATCH 8/8] Add another test, and improve code formatting --- cubed/primitive/blockwise.py | 10 +- cubed/tests/test_optimization.py | 194 ++++++++++++++----------------- 2 files changed, 93 insertions(+), 111 deletions(-) diff --git a/cubed/primitive/blockwise.py b/cubed/primitive/blockwise.py index efa3e7d9..14238278 100644 --- a/cubed/primitive/blockwise.py +++ b/cubed/primitive/blockwise.py @@ -357,12 +357,14 @@ def can_fuse_multiple_primitive_ops( if is_fuse_candidate(primitive_op) and all( is_fuse_candidate(p) for p in predecessor_primitive_ops ): - # if the peak projected memory for running all the predecessor ops in order is - # larger than allowed_mem then we can't fuse + # If the peak projected memory for running all the predecessor ops in + # order is larger than allowed_mem then we can't fuse. if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem: return False - # if the number of input blocks for each input is not uniform, then we can't fuse - # (this should never happen since all operations are currently uniform) + # If the number of input blocks for each input is not uniform, then we + # can't fuse. (This should never happen since all operations are + # currently uniform, and fused operations are too if fuse is applied in + # topological order.) 
num_input_blocks = primitive_op.pipeline.config.num_input_blocks if not all(num_input_blocks[0] == n for n in num_input_blocks): return False diff --git a/cubed/tests/test_optimization.py b/cubed/tests/test_optimization.py index a9096b14..4cf69b01 100644 --- a/cubed/tests/test_optimization.py +++ b/cubed/tests/test_optimization.py @@ -233,14 +233,10 @@ def test_fuse_unary_op(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (1,) + assert get_num_input_blocks(optimized_dag, c.name) == (1,) num_created_arrays = 2 # b, c assert c.plan.num_tasks(optimize_graph=False) == num_created_arrays + 2 @@ -278,13 +274,10 @@ def test_fuse_binary_op(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (e,)) - assert structurally_equivalent( - e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) - assert get_num_input_blocks( - e.plan.optimize(optimize_function=opt_fn).dag, e.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (1, 1) num_created_arrays = 3 # c, d, e assert e.plan.num_tasks(optimize_graph=False) == num_created_arrays + 3 @@ -324,13 +317,10 @@ def test_fuse_unary_and_binary_op(spec): add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (), (c,)) add_placeholder_op(expected_fused_dag, (a, b, c), (f,)) - assert structurally_equivalent( - f.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = f.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(f.plan.dag, f.name) == (1, 1) - assert get_num_input_blocks( - f.plan.optimize(optimize_function=opt_fn).dag, f.name - ) == (1, 1, 1) + assert get_num_input_blocks(optimized_dag, f.name) == (1, 1, 1) result = f.compute(optimize_function=opt_fn) assert_array_equal(result, np.ones((2, 2))) @@ -361,13 +351,10 @@ def test_fuse_mixed_levels(spec): add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (), (c,)) add_placeholder_op(expected_fused_dag, (a, b, c), (e,)) - assert structurally_equivalent( - e.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) - assert get_num_input_blocks( - e.plan.optimize(optimize_function=opt_fn).dag, e.name - ) == (1, 1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (1, 1, 1) result = e.compute(optimize_function=opt_fn) assert_array_equal(result, 3 * np.ones((2, 2))) @@ -395,13 +382,10 @@ def test_fuse_diamond(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) 
add_placeholder_op(expected_fused_dag, (a, a), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, d.name) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -433,13 +417,10 @@ def test_fuse_mixed_levels_and_diamond(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, d.name) == (1, 1) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((2, 2))) @@ -467,13 +448,10 @@ def test_fuse_repeated_argument(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a, a), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, c.name) == (1, 1) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -2 * np.ones((2, 2))) @@ -506,13 +484,10 @@ def test_fuse_other_dependents(spec): add_placeholder_op(expected_fused_dag, (a,), (c,)) add_placeholder_op(expected_fused_dag, (b,), (d,)) plan = arrays_to_plan(c, d) - assert structurally_equivalent( - plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (1,) + assert get_num_input_blocks(optimized_dag, c.name) == (1,) c_result, d_result = cubed.compute(c, d, optimize_function=opt_fn) assert_array_equal(c_result, np.ones((2, 2))) @@ -576,14 +551,10 @@ def stack_add(*a): ), (j,), ) - assert structurally_equivalent( - j.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = j.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(j.plan.dag, j.name) == (1,) - assert ( - get_num_input_blocks(j.plan.optimize(optimize_function=opt_fn).dag, j.name) - == (1,) * 8 - ) + assert get_num_input_blocks(optimized_dag, j.name) == (1,) * 8 result = j.compute(optimize_function=opt_fn) assert_array_equal(result, -8 * np.ones((2, 2))) @@ -639,13 +610,10 @@ def 
test_fuse_large_fan_in_default(spec): add_placeholder_op(expected_fused_dag, (a, b, c, d), (n,)) add_placeholder_op(expected_fused_dag, (e, f, g, h), (o,)) add_placeholder_op(expected_fused_dag, (n, o), (p,)) - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) - assert get_num_input_blocks( - p.plan.optimize(optimize_function=opt_fn).dag, p.name - ) == (1, 1) + assert get_num_input_blocks(optimized_dag, p.name) == (1, 1) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -712,14 +680,10 @@ def test_fuse_large_fan_in_override(spec): ), (p,), ) - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(p.plan.dag, p.name) == (1, 1) - assert ( - get_num_input_blocks(p.plan.optimize(optimize_function=opt_fn).dag, p.name) - == (1,) * 8 - ) + assert get_num_input_blocks(optimized_dag, p.name) == (1,) * 8 result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) @@ -727,21 +691,19 @@ def test_fuse_large_fan_in_override(spec): # now force everything to be fused with fuse_all_optimize_dag # note that max_total_source_arrays is *not* set opt_fn = fuse_all_optimize_dag - - assert structurally_equivalent( - p.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = p.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) result = p.compute(optimize_function=opt_fn) assert_array_equal(result, 8 * np.ones((2, 2))) -# merge chunks with same number of tasks +# merge chunks with same number of tasks (unary) # # a -> a -# | | +# | 3 | 3 # b c -# | +# | 1 # c # def test_fuse_with_merge_chunks_unary(spec): @@ -757,26 +719,55 @@ def test_fuse_with_merge_chunks_unary(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(b.plan.dag, b.name) == (3,) assert get_num_input_blocks(c.plan.dag, c.name) == (1,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (3,) + assert get_num_input_blocks(optimized_dag, c.name) == (3,) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((3, 2))) +# merge chunks with same number of tasks (binary) +# +# a b -> a b +# 3 | | 1 3 \ / 1 +# c d e +# 1 \ / 1 +# e +# +def test_fuse_with_merge_chunks_binary(spec): + a = xp.ones((3, 2), chunks=(1, 2), spec=spec) + b = xp.ones((3, 2), chunks=(3, 2), spec=spec) + c = merge_chunks_new(a, chunks=(3, 2)) + d = xp.negative(b) + e = xp.add(c, d) + + opt_fn = fuse_one_level(e) + + e.visualize(optimize_function=opt_fn) + + # check structure of optimized dag + expected_fused_dag = create_dag() + add_placeholder_op(expected_fused_dag, (), (a,)) + add_placeholder_op(expected_fused_dag, (), (b,)) + 
add_placeholder_op(expected_fused_dag, (a, b), (e,)) + optimized_dag = e.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) + assert get_num_input_blocks(e.plan.dag, e.name) == (1, 1) + assert get_num_input_blocks(optimized_dag, e.name) == (3, 1) + + result = e.compute(optimize_function=opt_fn) + assert_array_equal(result, np.zeros((3, 2))) + + # merge chunks with different number of tasks (b has more tasks than c) # # a -> a -# | | +# | 1 | 3 # b c -# | +# | 3 # c # def test_fuse_merge_chunks_unary(spec): @@ -794,15 +785,11 @@ def test_fuse_merge_chunks_unary(spec): expected_fused_dag = create_dag() add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (c,)) - assert structurally_equivalent( - c.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(b.plan.dag, b.name) == (1,) assert get_num_input_blocks(c.plan.dag, c.name) == (3,) - assert get_num_input_blocks( - c.plan.optimize(optimize_function=opt_fn).dag, c.name - ) == (3,) + assert get_num_input_blocks(optimized_dag, c.name) == (3,) result = c.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((3, 2))) @@ -811,9 +798,9 @@ def test_fuse_merge_chunks_unary(spec): # merge chunks with different number of tasks (c has more tasks than d) # # a b -> a b -# \ / \ / +# 1 \ / 1 3 \ / 3 # c d -# | +# | 3 # d # def test_fuse_merge_chunks_binary(spec): @@ -833,14 +820,11 @@ def test_fuse_merge_chunks_binary(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (), (b,)) add_placeholder_op(expected_fused_dag, (a, b), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, expected_fused_dag - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(c.plan.dag, c.name) == (1, 1) assert get_num_input_blocks(d.plan.dag, d.name) == (3,) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (3, 3) + assert get_num_input_blocks(optimized_dag, d.name) == (3, 3) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, 2 * np.ones((3, 2))) @@ -864,14 +848,10 @@ def test_fuse_only_optimize_dag(spec): add_placeholder_op(expected_fused_dag, (), (a,)) add_placeholder_op(expected_fused_dag, (a,), (b,)) add_placeholder_op(expected_fused_dag, (b,), (d,)) - assert structurally_equivalent( - d.plan.optimize(optimize_function=opt_fn).dag, - expected_fused_dag, - ) + optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag + assert structurally_equivalent(optimized_dag, expected_fused_dag) assert get_num_input_blocks(d.plan.dag, d.name) == (1,) - assert get_num_input_blocks( - d.plan.optimize(optimize_function=opt_fn).dag, d.name - ) == (1,) + assert get_num_input_blocks(optimized_dag, d.name) == (1,) result = d.compute(optimize_function=opt_fn) assert_array_equal(result, -np.ones((2, 2)))
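
Taken together, the series makes each blockwise operation record how many blocks every task reads from each input array, and multiplies those counts through when operations are fused. A minimal sketch of the composition rule (the helper name is illustrative; the real logic lives in fuse and fuse_multiple in cubed/primitive/blockwise.py):

from typing import Tuple


def compose_num_input_blocks(
    predecessor: Tuple[int, ...], successor_reads: int
) -> Tuple[int, ...]:
    # If the successor reads `successor_reads` blocks of the predecessor's
    # output, each of the predecessor's inputs is read that many times more
    # often in the fused operation.
    return tuple(n * successor_reads for n in predecessor)


# negative (1 block per input) fused into a 3-block merge_chunks_new,
# matching the expected values in test_fuse_merge_chunks_unary:
assert compose_num_input_blocks((1,), 3) == (3,)
# add (1 block from each of two inputs) fused into the same merge,
# matching test_fuse_merge_chunks_binary:
assert compose_num_input_blocks((1, 1), 3) == (3, 3)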
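
The block function in merge_chunks_new (patch 2) maps each output block to the source blocks it merges. A self-contained sketch of that mapping for the (1, 2) -> (3, 2) chunking used in the tests, with an illustrative array name "a" (the real code builds a nested list with dask's lol_product; with a single merged axis the flat product below yields the same keys):

import itertools

chunksize = (1, 2)         # source chunk size
target_chunksize = (3, 2)  # merged chunk size
numblocks = (3, 1)         # source blocks per axis for a (3, 2) array


def in_keys(out_coords):
    per_axis = []
    for i, (c0, c1) in enumerate(zip(chunksize, target_chunksize)):
        k = c1 // c0  # number of source blocks merged along axis i
        start = out_coords[i] * k
        per_axis.append(range(start, min(start + k, numblocks[i])))
    return [("a", *coords) for coords in itertools.product(*per_axis)]


# the single output block merges three source blocks along axis 0:
assert in_keys((0, 0)) == [("a", 0, 0), ("a", 1, 0), ("a", 2, 0)]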
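
Patch 5 is what lets a fused task execute a predecessor whose output it reads several blocks at a time: the predecessor's per-block function is mapped over each group of argument blocks, and the successor receives the resulting list (for merge_chunks_new, _concatenate2 then concatenates it). A simplified sketch of that dispatch, with plain values standing in for array blocks and the Iterator/tuple handling of the real apply_pipeline_func omitted:

def apply_pipeline_func(func, n_input_blocks, *args):
    if n_input_blocks == 1:
        return func(*args)  # one block per input: call the function directly
    # several blocks are read: apply the function to each group of blocks
    return [func(*group) for group in zip(*args)]


# fused negative -> merge_chunks: negate each of the three blocks; the
# successor then combines the list of results
assert apply_pipeline_func(lambda a: -a, 3, [1, 2, 3]) == [-1, -2, -3]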