Fuse operations with different numbers of tasks #368

Merged
merged 8 commits on Feb 5, 2024
3 changes: 2 additions & 1 deletion cubed/array_api/statistical_functions.py
@@ -25,7 +25,7 @@ def max(x, /, *, axis=None, keepdims=False):
return reduction(x, nxp.max, axis=axis, dtype=x.dtype, keepdims=keepdims)


def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False, split_every=None):
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in mean")
# This implementation uses NumPy and Zarr's structured arrays to store a
@@ -47,6 +47,7 @@ def mean(x, /, *, axis=None, keepdims=False, use_new_impl=False):
dtype=dtype,
keepdims=keepdims,
use_new_impl=use_new_impl,
split_every=split_every,

Member:
Is this split_every argument a temporary implementation detail whilst we figure out fusing heuristics? It would be nice to relate the meaning of this argument back to the discussion in #284, as I currently don't quite understand what it means.

Member Author:
The split_every argument is what we referred to as "fan-in" in #284: the number of chunks read by one task doing the reduction step. It's the same as in Dask, and it can be a dictionary indicating the number of chunks to read in each dimension.

I hope it is something that we can get better heuristics for (or at least good defaults), possibly by measuring trade-offs like they did in the Primula paper; see #331.
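
For illustration, a minimal sketch of how split_every might be passed, assuming a hypothetical array with 100 chunks along the reduced axis (the dict form mapping axis index to fan-in follows the Dask convention; the spec is omitted for brevity):

```python
import cubed
import cubed.random
import cubed.array_api as xp

# Hypothetical array: 100 chunks of size (10, 500) along axis 0.
a = cubed.random.random((1000, 500), chunks=(10, 500))

# Fan-in of 10: each reduction task reads 10 chunks along axis 0, so the
# tree reduction over 100 chunks needs ceil(log_10(100)) = 2 rounds.
m = xp.mean(a, axis=0, split_every=10, use_new_impl=True)

# Dict form: fan-in given per axis (only axis 0 is reduced here).
m2 = xp.mean(a, axis=0, split_every={0: 10}, use_new_impl=True)
```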

extra_func_kwargs=extra_func_kwargs,
)

67 changes: 65 additions & 2 deletions cubed/core/ops.py
@@ -27,10 +27,16 @@
from cubed.primitive.blockwise import blockwise as primitive_blockwise
from cubed.primitive.blockwise import general_blockwise as primitive_general_blockwise
from cubed.primitive.rechunk import rechunk as primitive_rechunk
from cubed.utils import chunk_memory, get_item, offset_to_block_id, to_chunksize
from cubed.utils import (
_concatenate2,
chunk_memory,
get_item,
offset_to_block_id,
to_chunksize,
)
from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks
from cubed.vendor.dask.array.utils import validate_axis
from cubed.vendor.dask.blockwise import broadcast_dimensions
from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
from cubed.vendor.dask.utils import has_keyword

if TYPE_CHECKING:
@@ -266,6 +272,7 @@ def blockwise(
extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

fusable = kwargs.pop("fusable", True)
num_input_blocks = kwargs.pop("num_input_blocks", None)

name = gensym()
spec = check_array_specs(arrays)
@@ -287,6 +294,7 @@
out_name=name,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
num_input_blocks=num_input_blocks,
**kwargs,
)
plan = Plan._new(
@@ -324,6 +332,8 @@ def general_blockwise(

extra_projected_mem = kwargs.pop("extra_projected_mem", 0)

num_input_blocks = kwargs.pop("num_input_blocks", None)

name = gensym()
spec = check_array_specs(arrays)
if target_store is None:
@@ -341,6 +351,7 @@ def general_blockwise(
chunks=chunks,
in_names=in_names,
extra_func_kwargs=extra_func_kwargs,
num_input_blocks=num_input_blocks,
**kwargs,
)
plan = Plan._new(
@@ -759,6 +770,7 @@ def rechunk(x, chunks, target_store=None):


def merge_chunks(x, chunks):
"""Merge multiple chunks into one."""
target_chunksize = chunks
if len(target_chunksize) != x.ndim:
raise ValueError(
@@ -787,6 +799,56 @@ def _copy_chunk(e, x, target_chunks=None, block_id=None):
return out


def merge_chunks_new(x, chunks):
# new implementation that uses general_blockwise rather than map_direct
target_chunksize = chunks
if len(target_chunksize) != x.ndim:
raise ValueError(
f"Chunks {target_chunksize} must have same number of dimensions as array ({x.ndim})"
)
if not all(c1 % c0 == 0 for c0, c1 in zip(x.chunksize, target_chunksize)):
raise ValueError(
f"Chunks {target_chunksize} must be a multiple of array's chunks {x.chunksize}"
)

target_chunks = normalize_chunks(chunks, x.shape, dtype=x.dtype)
axes = [
i for (i, (c0, c1)) in enumerate(zip(x.chunksize, target_chunksize)) if c0 != c1
]

def block_function(out_key):
out_coords = out_key[1:]

in_keys = []
for i, (c0, c1) in enumerate(zip(x.chunksize, target_chunksize)):
k = c1 // c0 # number of blocks to merge in axis i
if k == 1:
in_keys.append(out_coords[i])
else:
start = out_coords[i] * k
stop = min(start + k, x.numblocks[i])
in_keys.append(list(range(start, stop)))

# return a tuple with a single item that is the list of input keys to be merged
return (lol_product((x.name,), in_keys),)

num_input_blocks = int(
np.prod([c1 // c0 for (c0, c1) in zip(x.chunksize, target_chunksize)])
)

return general_blockwise(
_concatenate2,
block_function,
x,
shape=x.shape,
dtype=x.dtype,
chunks=target_chunks,
extra_projected_mem=0,
num_input_blocks=(num_input_blocks,),
axes=axes,
)


def reduction(
x: "Array",
func,
@@ -1059,6 +1121,7 @@ def block_function(out_key):
dtype=dtype,
chunks=chunks,
extra_projected_mem=extra_projected_mem,
num_input_blocks=(sum(split_every.values()),),
reduce_func=func,
initial_func=initial_func,
axis=axis,
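
To make the key mapping in merge_chunks_new concrete, here is a sketch of what block_function produces for one output block, using hypothetical array names, chunk sizes, and block counts (lol_product is the vendored Dask helper that expands a list-valued coordinate into a nested list of keys):

```python
from cubed.vendor.dask.blockwise import lol_product

# Suppose x has name "array-001", chunksize (10, 50) and numblocks (8, 4),
# and we merge to target chunksize (20, 50): k = 2 along axis 0, k = 1 along axis 1.
# For output block (1, 3), axis 0 reads blocks 2 and 3, and axis 1 reads block 3.
in_keys = [[2, 3], 3]

keys = lol_product(("array-001",), in_keys)
print(keys)  # [('array-001', 2, 3), ('array-001', 3, 3)]

# block_function wraps this in a single-item tuple, so _concatenate2 receives
# both input blocks and concatenates them into one output block; the
# operation's num_input_blocks is prod([2, 1]) = 2.
```
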
5 changes: 5 additions & 0 deletions cubed/core/plan.py
@@ -330,6 +330,11 @@ def visualize(

# remove pipeline attribute since it is a long string that causes graphviz to fail
if "pipeline" in d:
pipeline = d["pipeline"]
if pipeline.config is not None:
tooltip += (
f"\nnum input blocks: {pipeline.config.num_input_blocks}"
)
del d["pipeline"]

if "stack_summaries" in d and d["stack_summaries"] is not None:
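
A small sketch of how the new tooltip line can be inspected (hypothetical sizes; the spec is omitted for brevity):

```python
import cubed
import cubed.random
from cubed.core.ops import merge_chunks_new

a = cubed.random.random((200, 200), chunks=(10, 200))
b = merge_chunks_new(a, (20, 200))  # each task reads 2 input blocks

# Hovering over the node for this operation in the rendered SVG now shows a
# tooltip line like "num input blocks: (2,)".
b.visualize(filename="merge_chunks", optimize_graph=False)
```
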
93 changes: 74 additions & 19 deletions cubed/primitive/blockwise.py
@@ -1,5 +1,6 @@
import itertools
import math
from collections.abc import Iterator
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -45,6 +46,8 @@ class BlockwiseSpec:
A function that maps input chunks to an output chunk.
function_nargs: int
The number of array arguments that ``function`` takes.
num_input_blocks: Tuple[int, ...]
The number of input blocks read from each input array.
reads_map : Dict[str, CubedArrayProxy]
Read proxy dictionary keyed by array name.
write : CubedArrayProxy
@@ -54,6 +57,7 @@ class BlockwiseSpec:
block_function: Callable[..., Any]
function: Callable[..., Any]
function_nargs: int
num_input_blocks: Tuple[int, ...]
reads_map: Dict[str, CubedArrayProxy]
write: CubedArrayProxy

@@ -119,6 +123,7 @@ def blockwise(
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
num_input_blocks: Optional[Tuple[int, ...]] = None,
**kwargs,
):
"""Apply a function to multiple blocks from multiple inputs, expressed using concise indexing rules.
@@ -201,6 +206,7 @@ def blockwise(
extra_projected_mem=extra_projected_mem,
extra_func_kwargs=extra_func_kwargs,
fusable=fusable,
num_input_blocks=num_input_blocks,
**kwargs,
)

@@ -219,6 +225,7 @@ def general_blockwise(
extra_projected_mem: int = 0,
extra_func_kwargs: Optional[Dict[str, Any]] = None,
fusable: bool = True,
num_input_blocks: Optional[Tuple[int, ...]] = None,
**kwargs,
):
"""A more general form of ``blockwise`` that uses a function to specify the block
@@ -271,12 +278,18 @@ def general_blockwise(

func_kwargs = extra_func_kwargs or {}
func_with_kwargs = partial(func, **{**kwargs, **func_kwargs})
num_input_blocks = num_input_blocks or (1,) * len(arrays)
read_proxies = {
name: CubedArrayProxy(array, array.chunks) for name, array in array_map.items()
}
write_proxy = CubedArrayProxy(target_array, chunksize)
spec = BlockwiseSpec(
block_function, func_with_kwargs, len(arrays), read_proxies, write_proxy
block_function,
func_with_kwargs,
len(arrays),
num_input_blocks,
read_proxies,
write_proxy,
)

# calculate projected memory
@@ -344,10 +357,17 @@ def can_fuse_multiple_primitive_ops(
if is_fuse_candidate(primitive_op) and all(
is_fuse_candidate(p) for p in predecessor_primitive_ops
):
# if the peak projected memory for running all the predecessor ops in order is
# larger than allowed_mem then we can't fuse
# If the peak projected memory for running all the predecessor ops in
# order is larger than allowed_mem then we can't fuse.
if peak_projected_mem(predecessor_primitive_ops) > primitive_op.allowed_mem:
return False
# If the number of input blocks for each input is not uniform, then we
# can't fuse. (This should never happen since all operations are
# currently uniform, and fused operations are too if fuse is applied in
# topological order.)
num_input_blocks = primitive_op.pipeline.config.num_input_blocks
if not all(num_input_blocks[0] == n for n in num_input_blocks):
return False
return all(
primitive_op.num_tasks == p.num_tasks for p in predecessor_primitive_ops
)
@@ -390,8 +410,17 @@ def fused_func(*args):
function_nargs = pipeline1.config.function_nargs
read_proxies = pipeline1.config.reads_map
write_proxy = pipeline2.config.write
num_input_blocks = tuple(
n * pipeline2.config.num_input_blocks[0]
for n in pipeline1.config.num_input_blocks
)
spec = BlockwiseSpec(
fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy
fused_blockwise_func,
fused_func,
function_nargs,
num_input_blocks,
read_proxies,
write_proxy,
)

target_array = primitive_op2.target_array
@@ -424,12 +453,6 @@ def fuse_multiple(
Fuse a blockwise operation and its predecessors into a single operation, avoiding writing to (or reading from) the targets of the predecessor operations.
"""

assert all(
primitive_op.num_tasks == p.num_tasks
for p in predecessor_primitive_ops
if p is not None
)

pipeline = primitive_op.pipeline
predecessor_pipelines = [
primitive_op.pipeline if primitive_op is not None else None
@@ -444,42 +467,74 @@

mappable = pipeline.mappable

def apply_pipeline_block_func(pipeline, arg):
def apply_pipeline_block_func(pipeline, n_input_blocks, arg):
if pipeline is None:
return (arg,)
return pipeline.config.block_function(arg)
if n_input_blocks == 1:
assert isinstance(arg, tuple)
return pipeline.config.block_function(arg)
else:
# more than one input block is being read from arg
assert isinstance(arg, (list, Iterator))
return tuple(
list(item)
for item in zip(*(pipeline.config.block_function(a) for a in arg))
)

def fused_blockwise_func(out_key):
# this will change when multiple outputs are supported
args = pipeline.config.block_function(out_key)
# split all args to the fused function into groups, one for each predecessor function
func_args = tuple(
item
for p, a in zip(predecessor_pipelines, args)
for item in apply_pipeline_block_func(p, a)
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
for item in apply_pipeline_block_func(
p, pipeline.config.num_input_blocks[i], a
)
)
return split_into(func_args, predecessor_funcs_nargs)

def apply_pipeline_func(pipeline, *args):
def apply_pipeline_func(pipeline, n_input_blocks, *args):
if pipeline is None:
return args[0]
return pipeline.config.function(*args)
if n_input_blocks == 1:
ret = pipeline.config.function(*args)
else:
# more than one input block is being read from this group of args to primitive op
ret = [pipeline.config.function(*item) for item in list(zip(*args))]
return ret

def fused_func(*args):
# args are grouped appropriately so they can be called by each predecessor function
func_args = [
apply_pipeline_func(p, *a) for p, a in zip(predecessor_pipelines, args)
apply_pipeline_func(p, pipeline.config.num_input_blocks[i], *a)
for i, (p, a) in enumerate(zip(predecessor_pipelines, args))
]
return pipeline.config.function(*func_args)

function_nargs = pipeline.config.function_nargs
fused_function_nargs = pipeline.config.function_nargs
# ok to get num_input_blocks[0] since it is uniform (see check in can_fuse_multiple_primitive_ops)
fused_num_input_blocks = tuple(
pipeline.config.num_input_blocks[0] * n
for n in itertools.chain(
*(
p.pipeline.config.num_input_blocks if p is not None else (1,)
for p in predecessor_primitive_ops
)
)
)
read_proxies = dict(pipeline.config.reads_map)
for p in predecessor_pipelines:
if p is not None:
read_proxies.update(p.config.reads_map)
write_proxy = pipeline.config.write
spec = BlockwiseSpec(
fused_blockwise_func, fused_func, function_nargs, read_proxies, write_proxy
fused_blockwise_func,
fused_func,
fused_function_nargs,
fused_num_input_blocks,
read_proxies,
write_proxy,
)

target_array = primitive_op.target_array
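
As a worked example of the fused_num_input_blocks calculation in fuse_multiple (hypothetical fan-ins): if the successor op reads 1 block per task from each of two predecessors, and the first predecessor itself reads 4 blocks per task while the second reads 1, the fused op reads (4, 1) blocks per task from the original inputs. A sketch of the arithmetic:

```python
import itertools

# Successor's fan-in per input (uniform, as checked in
# can_fuse_multiple_primitive_ops), and each predecessor's fan-in per input.
successor_num_input_blocks = (1, 1)
predecessor_num_input_blocks = [(4,), (1,)]

# Mirrors fuse_multiple: multiply the successor's fan-in through the
# predecessors' fan-ins to get the fused operation's fan-in per original input.
fused = tuple(
    successor_num_input_blocks[0] * n
    for n in itertools.chain(*predecessor_num_input_blocks)
)
print(fused)  # (4, 1)
```
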
15 changes: 12 additions & 3 deletions cubed/tests/test_core.py
@@ -12,7 +12,7 @@
import cubed.random
from cubed.backend_array_api import namespace as nxp
from cubed.core.ops import merge_chunks, partial_reduce, tree_reduce
from cubed.core.optimization import fuse_all_optimize_dag
from cubed.core.optimization import fuse_all_optimize_dag, multiple_inputs_optimize_dag
from cubed.tests.utils import (
ALL_EXECUTORS,
MAIN_EXECUTORS,
@@ -531,10 +531,19 @@ def test_plan_quad_means(tmp_path, t_length):
u = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
v = cubed.random.random((t_length, 1, 987, 1920), chunks=(10, 1, -1, -1), spec=spec)
uv = u * v
m = xp.mean(uv, axis=0)
m = xp.mean(uv, axis=0, split_every=10, use_new_impl=True)

assert m.plan.num_tasks() > 0
m.visualize(filename=tmp_path / "quad_means")
m.visualize(
filename=tmp_path / "quad_means_unoptimized",
optimize_graph=False,
show_hidden=True,
)
m.visualize(
filename=tmp_path / "quad_means",
optimize_function=multiple_inputs_optimize_dag,
show_hidden=True,
)


def quad_means(tmp_path, t_length):