Skip to content

Commit

Permalink
Change default max_total_num_input_blocks to 10, to allow operations
Browse files Browse the repository at this point in the history
to be fused as long as the total number of input blocks does not exceed this.

Previously, the default was None, which meant operations were fused
only if they had the same number of tasks.
  • Loading branch information
tomwhite committed Nov 10, 2024
1 parent 09bb9de commit b7c6a0a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
15 changes: 9 additions & 6 deletions cubed/core/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@

logger = logging.getLogger(__name__)

DEFAULT_MAX_TOTAL_SOURCE_ARRAYS = 4
DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS = 10


def simple_optimize_dag(dag, array_names=None):
"""Apply map blocks fusion."""
Expand Down Expand Up @@ -154,8 +157,8 @@ def can_fuse_predecessors(
name,
*,
array_names=None,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
always_fuse=None,
never_fuse=None,
):
Expand Down Expand Up @@ -242,8 +245,8 @@ def fuse_predecessors(
name,
*,
array_names=None,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
always_fuse=None,
never_fuse=None,
):
Expand Down Expand Up @@ -297,8 +300,8 @@ def multiple_inputs_optimize_dag(
dag,
*,
array_names=None,
max_total_source_arrays=4,
max_total_num_input_blocks=None,
max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
always_fuse=None,
never_fuse=None,
):
Expand Down
44 changes: 33 additions & 11 deletions cubed/tests/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import elemwise, merge_chunks, partial_reduce
from cubed.core.optimization import (
DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
fuse_all_optimize_dag,
fuse_only_optimize_dag,
fuse_predecessors,
Expand Down Expand Up @@ -223,7 +225,11 @@ def fuse_one_level(arr, *, always_fuse=None):
)


def fuse_multiple_levels(*, max_total_source_arrays=4, max_total_num_input_blocks=None):
def fuse_multiple_levels(
*,
max_total_source_arrays=DEFAULT_MAX_TOTAL_SOURCE_ARRAYS,
max_total_num_input_blocks=DEFAULT_MAX_TOTAL_NUM_INPUT_BLOCKS,
):
# use multiple_inputs_optimize_dag to test multiple levels of fusion
return partial(
multiple_inputs_optimize_dag,
Expand Down Expand Up @@ -899,8 +905,7 @@ def test_fuse_merge_chunks_unary(spec):
b = xp.negative(a)
c = merge_chunks(b, chunks=(3, 2))

# specify max_total_num_input_blocks to force c to fuse
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
opt_fn = fuse_multiple_levels()

c.visualize(optimize_function=opt_fn)

Expand All @@ -921,6 +926,16 @@ def test_fuse_merge_chunks_unary(spec):
result = c.compute(optimize_function=opt_fn)
assert_array_equal(result, -np.ones((3, 2)))

# now set max_total_num_input_blocks=None which means
# "only fuse if ops have same number of tasks", which they don't here
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None)
optimized_dag = c.plan.optimize(optimize_function=opt_fn).dag

# merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
assert not structurally_equivalent(
optimized_dag, expected_fused_dag, remove_hidden=True
)


# merge chunks with different number of tasks (c has more tasks than d)
#
Expand All @@ -936,8 +951,7 @@ def test_fuse_merge_chunks_binary(spec):
c = xp.add(a, b)
d = merge_chunks(c, chunks=(3, 2))

# specify max_total_num_input_blocks to force d to fuse
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
opt_fn = fuse_multiple_levels()

d.visualize(optimize_function=opt_fn)

Expand All @@ -963,15 +977,24 @@ def test_fuse_merge_chunks_binary(spec):
result = d.compute(optimize_function=opt_fn)
assert_array_equal(result, 2 * np.ones((3, 2)))

# now set max_total_num_input_blocks=None which means
# "only fuse if ops have same number of tasks", which they don't here
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=None)
optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag

# merge_chunks uses a hidden op and array for block ids - ignore when comparing structure
assert not structurally_equivalent(
optimized_dag, expected_fused_dag, remove_hidden=True
)


# like test_fuse_merge_chunks_unary, except uses partial_reduce
def test_fuse_partial_reduce_unary(spec):
a = xp.ones((3, 2), chunks=(1, 2), spec=spec)
b = xp.negative(a)
c = partial_reduce(b, np.sum, split_every={0: 3})

# specify max_total_num_input_blocks to force c to fuse
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=3)
opt_fn = fuse_multiple_levels()

c.visualize(optimize_function=opt_fn)

Expand All @@ -996,8 +1019,7 @@ def test_fuse_partial_reduce_binary(spec):
c = xp.add(a, b)
d = partial_reduce(c, np.sum, split_every={0: 3})

# specify max_total_num_input_blocks to force d to fuse
opt_fn = fuse_multiple_levels(max_total_num_input_blocks=6)
opt_fn = fuse_multiple_levels()

d.visualize(optimize_function=opt_fn)

Expand Down Expand Up @@ -1176,7 +1198,7 @@ def test_optimize_stack(spec):
c = xp.stack((a, b), axis=0)
d = c + 1
# try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b)
d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10))
d.compute(optimize_function=fuse_multiple_levels())


def test_optimize_concat(spec):
Expand All @@ -1186,4 +1208,4 @@ def test_optimize_concat(spec):
c = xp.concat((a, b), axis=0)
d = c + 1
# try to fuse all ops into one (d will fuse with c, but c won't fuse with a and b)
d.compute(optimize_function=fuse_multiple_levels(max_total_num_input_blocks=10))
d.compute(optimize_function=fuse_multiple_levels())
4 changes: 2 additions & 2 deletions docs/user-guide/optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ e.visualize(optimize_function=opt_fn)

The `max_total_num_input_blocks` argument to `multiple_inputs_optimize_dag` specifies the maximum number of input blocks (chunks) that are allowed in the fused operation.

Again, this is to limit the number of reads that an individual task must perform. The default is `None`, which means that operations are fused only if they have the same number of tasks. If set to an integer, then this limitation is removed, and tasks with a different number of tasks will be fused - as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be set using `functools.partial`:
Again, this is to limit the number of reads that an individual task must perform. If set to `None`, operations are fused only if they have the same number of tasks. If set to an integer (the default is 10), then tasks with a different number of tasks will be fused - as long as the total number of input blocks does not exceed the maximum. This setting is useful for reductions, and can be changed using `functools.partial`:

```python
opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=10)
opt_fn = partial(multiple_inputs_optimize_dag, max_total_num_input_blocks=20)
e.visualize(optimize_function=opt_fn)
```

0 comments on commit b7c6a0a

Please sign in to comment.