Lower subarray patterns to new grid mapped array nodes #125

Open
wants to merge 1 commit into base: spr/orausch/master.lower-subarray-patterns-to-new-grid-mapped-array-nodes
9 changes: 6 additions & 3 deletions daceml/distributed/communication/grid_mapped_array.py
@@ -215,9 +215,12 @@ def expansion(node: GridMapper, state: SDFGState, sdfg: SDFG):
replication_grid_name = None

pgrid_desc = sdfg.process_grids[partition_grid_name]
unique_id = "{}_{}_{}_{}".format("Scatter" if scatter else "Gather",
sdfg.sdfg_id, sdfg.node_id(state),
state.node_id(node))

in_name = state.in_edges(node)[0].data.data
out_name = state.out_edges(node)[0].data.data
unique_id = "{}_{}_{}_{}_{}_{}".format(
"Scatter" if scatter else "Gather", sdfg.sdfg_id,
sdfg.node_id(state), state.node_id(node), in_name, out_name)
mpi_dtype = mpi.utils.MPI_DDT(out_desc.dtype.base_type)

array_desc = inp_desc if scatter else out_desc
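For reference, a minimal standalone sketch of the old versus new identifier formats (the argument values here are hypothetical; in the PR they come from sdfg.sdfg_id, sdfg.node_id(state), state.node_id(node) and the memlet data names):

def old_unique_id(kind, sdfg_id, state_id, node_id):
    # Old format: structural ids only.
    return "{}_{}_{}_{}".format(kind, sdfg_id, state_id, node_id)

def new_unique_id(kind, sdfg_id, state_id, node_id, in_name, out_name):
    # New format: structural ids plus the input and output array names.
    return "{}_{}_{}_{}_{}_{}".format(kind, sdfg_id, state_id, node_id,
                                      in_name, out_name)

print(old_unique_id("Scatter", 0, 0, 3))                  # Scatter_0_0_3
print(new_unique_id("Scatter", 0, 0, 3, "X", "X_local"))  # Scatter_0_0_3_X_X_local
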
260 changes: 61 additions & 199 deletions daceml/distributed/communication/subarrays.py
@@ -10,6 +10,8 @@

from .. import utils as distr_utils

from . import grid_mapped_array as grid_array

if typing.TYPE_CHECKING:
from .node import DistributedMemlet

@@ -83,66 +85,12 @@ def match_subset_axis_to_pgrid(
return blocks
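For context, the body of match_subset_axis_to_pgrid is collapsed above; from how its result is consumed below, it appears to return one (grid variable or None, block size) pair per array axis. A hypothetical illustration of that shape, using plain sympy symbols as stand-ins for the dace process-grid variables:

import sympy as sp

i, j = sp.symbols("i j")  # stand-ins for process-grid index variables

# Assumed result for a subset like [32*i : 32*i + 32, 0:16] over grid variables [i, j]:
matched_dimensions = [
    (i, 32),     # array axis 0 is blocked over grid variable i with block size 32
    (None, 16),  # array axis 1 is not tied to any grid variable; block = full extent
]
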


def compute_scatter_color(parent_grid_variables: List[symbolic.symbol],
parent_grid_shape: List[int],
matched_dimensions: MatchedDimensions) -> List[bool]:
# avoid import loop
from .node import FULLY_REPLICATED_RANK

# We need to setup a broadcast grid to make sure the ranks on the
# remaining dimensions of the scatter grid get their values. We will
# split the process grid into two subgrids to achieve this

dim_to_idx = {
v: i
for i, v in enumerate(parent_grid_variables)
if v.name != FULLY_REPLICATED_RANK
}

# these are the dimensions we need to scatter our data over (i.e. not replicated)
scattered_dims = {
dim_to_idx[s]
for s, _ in matched_dimensions if s is not None
}

# these are the number of unpartitioned dimensions in our subset
# these need to be mapped to a dimension of the process grid with size one
required_full_replication_dims = len(matched_dimensions) - len(
scattered_dims)

# empty dims are the indices of dims of size 1 in the global pgrid
empty_dims = {
i
for i, s in enumerate(parent_grid_variables)
if s.name == FULLY_REPLICATED_RANK
}
assert all(parent_grid_shape[i] == 1 for i in empty_dims)

# the indices of size 1 dimensions we choose for the scatter grid
empty_scatter_dims = set()
for i in range(required_full_replication_dims):
# choose any empty rank
parent_empty_dim = next(iter(empty_dims))
empty_dims.remove(parent_empty_dim)
scattered_dims.add(parent_empty_dim)
empty_scatter_dims.add(parent_empty_dim)

# scatter_color[i] == True iff the i'th dimension is kept in the scatter grid
scatter_color = [
i in scattered_dims for i in range(len(parent_grid_shape))
]
return scatter_color


def try_construct_subarray(
def try_match_constraint(
sdfg: SDFG, state: SDFGState, pgrid_name: str, global_desc: data.Data,
subset: subsets.Range, grid_variables: List[symbolic.symbol],
scatter: bool,
dry_run: bool) -> Optional[Tuple[str, str, Optional[str]]]:
subset: subsets.Range,
grid_variables: List[symbolic.symbol]) -> List[grid_array.AxisScheme]:
"""
Try to convert the given end of the distributed memlet to a subarray,
creating the necessary process grids. Returns the name of the subarray, the
name of the scatter grid and the name of the bcast/reduction grid.
Try to parse the given communication constraint as a grid mapped array.

:param sdfg: The SDFG to operate on.
:param state: The state to operate on (Dummy tasklets will be inserted
@@ -152,121 +100,59 @@ def try_construct_subarray(
:param subset: The end of the distributed memlet to convert.
:param grid_variables: The process grid corresponding to the computation to
which this is either an input or an output.
:param scatter: True if this is for a scatter.
:param dry_run: If True, don't actually create the grids and subarray.
Instead, return None.
:return: The name of the subarray, the name of the scatter grid and the
name of the bcast grid (possibly ``None`` if no broadcast is
necessary).
"""
:return: the axis scheme mapping implementing the communication constraint.

# squeeze dimensions
squeeze = lambda shape: [s for s in shape if s > 1]
:raises: CommunicationSolverException if it cannot be solved.
"""

global_shape = squeeze(global_desc.shape)
global_shape = global_desc.shape
pgrid = sdfg.process_grids[pgrid_name]

matched_dimensions = match_subset_axis_to_pgrid(subset, grid_variables)
# drop array dimensions of size 1
matched_dimensions = [(v, bs)
for i, (v, bs) in enumerate(matched_dimensions)
if global_desc.shape[i] > 1]

if len(matched_dimensions) > len(pgrid.shape):
# haven't thought about this yet
raise NotImplementedError()
elif len(matched_dimensions) < len(pgrid.shape):

scatter_color = compute_scatter_color(grid_variables, pgrid.shape,
matched_dimensions)
assert sum(scatter_color) == len(global_shape)
# The broadcast grid provides the data to the ranks that are not
# covered in the scatter grid.
# It's always the inverse of the scatter grid.
bcast_color = list(map(lambda x: not x, scatter_color))

subgrid_shape = lambda color: [
s for s, c in zip(pgrid.shape, color) if c
]
scatter_shape = subgrid_shape(scatter_color)
bcast_shape = subgrid_shape(bcast_color)

if not dry_run:
scatter_grid_name = sdfg.add_pgrid(
shape=scatter_shape,
parent_grid=pgrid_name,
color=scatter_color,
exact_grid=None if scatter else 0)

bcast_grid_name = sdfg.add_pgrid(shape=bcast_shape,
parent_grid=pgrid_name,
color=bcast_color)
for name, shape in ((scatter_grid_name, scatter_shape),
(bcast_grid_name, bcast_shape)):
distr_utils.initialize_fields(state, [
f'MPI_Comm {name}_comm;',
f'MPI_Group {name}_group;',
f'int {name}_coords[{len(shape)}];',
f'int {name}_dims[{len(shape)}];',
f'int {name}_rank;',
f'int {name}_size;',
f'bool {name}_valid;',
])
current_idx = 0
symbol_to_idx = {}
for i, s in enumerate(grid_variables):
if scatter_color[i]:
symbol_to_idx[s] = current_idx
current_idx += 1
else:
scatter_shape = pgrid.shape
scatter_grid_name = pgrid_name
bcast_grid_name = None
symbol_to_idx = {s: i for i, s in enumerate(grid_variables)}

# The indices of dims of size 1 in the scatter_grid
empty_dims = {i for i, s in enumerate(scatter_shape) if s == 1}

# Now we need to map dimensions of the subset to the dimensions in the
# scatter grid
correspondence: List[int] = []

for i, (p, bs) in enumerate(matched_dimensions):
if p is None:
# we can choose to map to any of the empty ranks
chosen = next(iter(empty_dims))
empty_dims.remove(chosen)
correspondence.append(chosen)
else:
grid_index = symbol_to_idx[p]

global_block_size = global_shape[i] / scatter_shape[grid_index]
if global_block_size % bs != 0:
return None

if bs != global_block_size:
raise CommunicationSolverException(
f"Detected block size {bs} (on axis {subset[i]}) does not"
"match global block size {global_block_size}")
correspondence.append(grid_index)

assert all(map(lambda i: 0 <= i < len(scatter_shape), correspondence))

if not dry_run:
# create subarray
subarray_name = sdfg.add_subarray(dtype=global_desc.dtype,
shape=global_shape,
subshape=squeeze(
subset.size_exact()),
pgrid=scatter_grid_name,
correspondence=correspondence)
distr_utils.initialize_fields(state, [
f'MPI_Datatype {subarray_name};', f'int* {subarray_name}_counts;',
f'int* {subarray_name}_displs;'
])

return subarray_name, scatter_grid_name, bcast_grid_name
else:
return None
axis_mapping: List[Optional[grid_array.AxisScheme]] = [None] * len(
pgrid.shape)

# Assign each partitioned axis
for i, (matched_dim, matched_bs) in enumerate(matched_dimensions):
if matched_dim is None:
# this is a replicated dimension, we assign these in a second pass
# since they have no block size constraint
continue
grid_index = grid_variables.index(matched_dim)
global_block_size = global_shape[i] / pgrid.shape[grid_index]
if matched_bs != global_block_size:
raise CommunicationSolverException(
f"Detected block size {matched_bs} (on axis {subset[i]}) does not"
"match global block size {global_block_size}")

scheme = grid_array.AxisScheme(axis=i,
scheme=grid_array.AxisType.PARTITION)
axis_mapping[grid_index] = scheme

unassigned_dims = {i for i, a in enumerate(axis_mapping) if a is None}

# Assign replicated axes
for i, (matched_dim, matched_bs) in enumerate(matched_dimensions):
if matched_dim is not None:
continue
grid_index = unassigned_dims.pop()
scheme = grid_array.AxisScheme(axis=i,
scheme=grid_array.AxisType.REPLICATE)
axis_mapping[grid_index] = scheme

# Any remaining process grid axes are broadcast
for i, scheme in enumerate(axis_mapping):
if scheme is not None:
continue
scheme = grid_array.AxisScheme(axis=None,
scheme=grid_array.AxisType.BROADCAST)
axis_mapping[i] = scheme
return axis_mapping
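
To make the returned mapping concrete, here is a hypothetical result for a process grid of shape [2, 1, 2] and a 2-D array whose axis 0 is partitioned over the first grid dimension and whose axis 1 is replicated. The AxisType/AxisScheme definitions below are stand-ins for the real ones in grid_mapped_array.py, which are not shown in this diff:

import enum
from dataclasses import dataclass
from typing import Optional

class AxisType(enum.Enum):       # stand-in for grid_mapped_array.AxisType
    PARTITION = enum.auto()      # array axis is split across this grid axis
    REPLICATE = enum.auto()      # array axis is kept whole on this grid axis
    BROADCAST = enum.auto()      # grid axis carries no array axis; data is broadcast

@dataclass
class AxisScheme:                # stand-in for grid_mapped_array.AxisScheme
    axis: Optional[int]          # array axis index, or None for BROADCAST
    scheme: AxisType

# One possible axis_mapping, indexed by process-grid dimension:
axis_mapping = [
    AxisScheme(axis=0, scheme=AxisType.PARTITION),    # grid dim 0 (size 2) partitions array axis 0
    AxisScheme(axis=1, scheme=AxisType.REPLICATE),    # grid dim 1 (size 1) replicates array axis 1
    AxisScheme(axis=None, scheme=AxisType.BROADCAST), # grid dim 2 (size 2) only broadcasts
]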


class CommunicateSubArrays(pm.ExpandTransformation):
@@ -286,25 +172,13 @@ def can_be_applied(self, state: SDFGState, *_, **__):
try:
if node.src_pgrid is not None:
garr = sdfg.arrays[node.src_global_array]
try_construct_subarray(sdfg,
state,
node.src_pgrid,
garr,
node.src_subset,
src_vars,
False,
dry_run=True)
try_match_constraint(sdfg, state, node.src_pgrid, garr,
node.src_subset, src_vars)

if node.dst_pgrid is not None:
garr = sdfg.arrays[node.dst_global_array]
try_construct_subarray(sdfg,
state,
node.dst_pgrid,
garr,
node.dst_subset,
dst_vars,
False,
dry_run=True)
try_match_constraint(sdfg, state, node.dst_pgrid, garr,
node.dst_subset, dst_vars)
except CommunicationSolverException:
return False

@@ -342,26 +216,14 @@ def expansion(node: 'DistributedMemlet', state: SDFGState, sdfg: SDFG):
subset = node.src_subset
rvars = src_vars

subarray_name, scatter_grid, bcast_grid = try_construct_subarray(
sdfg,
state,
pgrid_name,
garr,
subset,
rvars,
scatter,
dry_run=False)
axis_mapping = try_match_constraint(sdfg, state, pgrid_name, garr,
subset, rvars)

if scatter:
expansion = mpi.BlockScatter(node.label,
subarray_type=subarray_name,
scatter_grid=scatter_grid,
bcast_grid=bcast_grid)
else:
expansion = mpi.BlockGather(node.label,
subarray_type=subarray_name,
gather_grid=scatter_grid,
reduce_grid=bcast_grid)
cls = grid_array.ScatterOntoGrid if scatter else grid_array.GatherFromGrid

expansion = cls(node.label,
grid_name=pgrid_name,
axis_mapping=axis_mapping)

# clean up connectors to match the new node
expansion.add_in_connector("_inp_buffer")
7 changes: 5 additions & 2 deletions tests/distributed/test_elementwise.py
@@ -15,12 +15,13 @@
[2],
[4],
])
def test_elementwise_1d(sizes):
def test_elementwise_1d(sizes, sdfg_name):
@dace
def program(x: dace.int64[64]):
return x + 5

sdfg = program.to_sdfg()
sdfg.name = sdfg_name

map_entry = find_map_containing(sdfg, "")
schedule.lower(sdfg, {map_entry: sizes})
@@ -30,6 +31,7 @@ def program(x: dace.int64[64]):
compile_and_call(sdfg, {'x': X.copy()}, expected, utils.prod(sizes))


@pytest.mark.skip("Not implemented")
@pytest.mark.parametrize("sizes", [
[2],
])
@@ -62,14 +64,15 @@ def program(x: dace.int64[64, 1]):
[2, 2, 2], # no broadcast grid, 2d scatter grid
[1, 2, 1],
])
def test_bcast_simple(sizes):
def test_bcast_simple(sizes, sdfg_name):
@dace
def program(x: dace.int64[4, 8, 16], y: dace.int64[8, 16]):
return x + y

sdfg = program.to_sdfg()
sdfg.expand_library_nodes()
sdfg.simplify()
sdfg.name = sdfg_name

map_entry = find_map_containing(sdfg, "")
schedule.lower(sdfg, {map_entry: sizes})
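The tests now take an sdfg_name argument and rename the SDFG before lowering, presumably so each parametrized run compiles into its own build directory. The fixture providing that name is not part of this diff; a plausible sketch of it (hypothetical) is:

import re
import pytest

@pytest.fixture
def sdfg_name(request) -> str:
    # Derive a unique, identifier-safe SDFG name from the current test id,
    # e.g. "test_bcast_simple[sizes1]" -> "test_bcast_simple_sizes1_".
    return re.sub(r"[^0-9a-zA-Z_]", "_", request.node.name)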