diff --git a/.github/workflows/cpu-ci.yml b/.github/workflows/cpu-ci.yml index 52e81a97..8223308f 100644 --- a/.github/workflows/cpu-ci.yml +++ b/.github/workflows/cpu-ci.yml @@ -58,7 +58,7 @@ jobs: - name: Test with pytest env: ORT_RELEASE: ${{ github.workspace }}/onnxruntime-daceml-patched - PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx" --timeout=500 + PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx and not mpi" --timeout=500 run: make test - name: Test with doctest @@ -95,7 +95,7 @@ jobs: - name: Test with pytest env: - PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx" --timeout=500 --skip-ort + PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx and not mpi" --timeout=500 --skip-ort run: make test - name: Upload coverage diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index 93282b78..f836a771 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -29,7 +29,7 @@ jobs: - name: Test with pytest env: - PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc --gpu-only -m "not slow and not fpga and not xilinx and not onnx" --timeout=500 + PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc --gpu-only -m "not slow and not fpga and not xilinx and not onnx and not mpi" --timeout=500 run: make test - name: Upload coverage diff --git a/daceml/distributed/communication/subarrays.py b/daceml/distributed/communication/subarrays.py index 633a499f..1948bcde 100644 --- a/daceml/distributed/communication/subarrays.py +++ b/daceml/distributed/communication/subarrays.py @@ -137,6 +137,7 @@ def compute_scatter_color(parent_grid_variables: List[symbolic.symbol], def try_construct_subarray( sdfg: SDFG, state: SDFGState, pgrid_name: str, global_desc: data.Data, subset: subsets.Range, grid_variables: List[symbolic.symbol], + scatter: bool, dry_run: bool) -> Optional[Tuple[str, str, Optional[str]]]: """ Try to convert the given end of the distributed memlet to a subarray, @@ -151,6 +152,7 @@ def try_construct_subarray( :param subset: The end of the distributed memlet to convert. :param grid_variables: The process grid corresponding to the computation to which this is either an input or an output. + :param scatter: True if this is for a scatter. :param dry_run: If True, don't actually create the grids and subarray. Instead, return None. :return: The name of the subarray, the name of the scatter grid and the @@ -182,13 +184,15 @@ def try_construct_subarray( bcast_shape = subgrid_shape(bcast_color) if not dry_run: - scatter_grid_name = sdfg.add_pgrid(shape=scatter_shape, - parent_grid=pgrid_name, - color=scatter_color) + scatter_grid_name = sdfg.add_pgrid( + shape=scatter_shape, + parent_grid=pgrid_name, + color=scatter_color, + exact_grid=None if scatter else 0) + bcast_grid_name = sdfg.add_pgrid(shape=bcast_shape, parent_grid=pgrid_name, color=bcast_color) - for name, shape in ((scatter_grid_name, scatter_shape), (bcast_grid_name, bcast_shape)): distr_utils.initialize_fields(state, [ @@ -280,6 +284,7 @@ def can_be_applied(self, state: SDFGState, *_, **__): garr, node.src_subset, src_vars, + False, dry_run=True) if node.dst_pgrid is not None: @@ -290,6 +295,7 @@ def can_be_applied(self, state: SDFGState, *_, **__): garr, node.dst_subset, dst_vars, + False, dry_run=True) except CommunicationSolverException: return False @@ -329,7 +335,14 @@ def expansion(node: 'DistributedMemlet', state: SDFGState, sdfg: SDFG): rvars = src_vars subarray_name, scatter_grid, bcast_grid = try_construct_subarray( - sdfg, state, pgrid_name, garr, subset, rvars, dry_run=False) + sdfg, + state, + pgrid_name, + garr, + subset, + rvars, + scatter, + dry_run=False) if scatter: expansion = mpi.BlockScatter(node.label, @@ -339,7 +352,7 @@ def expansion(node: 'DistributedMemlet', state: SDFGState, sdfg: SDFG): else: expansion = mpi.BlockGather(node.label, subarray_type=subarray_name, - gather_grid=pgrid_name, + gather_grid=scatter_grid, reduce_grid=bcast_grid) # clean up connectors to match the new node diff --git a/daceml/distributed/schedule.py b/daceml/distributed/schedule.py index a4d72e1a..e39b00c4 100644 --- a/daceml/distributed/schedule.py +++ b/daceml/distributed/schedule.py @@ -2,11 +2,14 @@ A distributed schedule for a subgraph is a mapping from each map entry node to process grid dimensions. """ +import copy import collections -from typing import List, Tuple, Dict, Set, Optional +from typing import List, Tuple, Dict, Set, Optional, Union import itertools -import sympy +import networkx as nx + +import sympy as sp from dace import nodes, SDFG, SDFGState, symbolic, subsets, memlet, data from dace.sdfg import propagation, utils as sdfg_utils @@ -59,7 +62,7 @@ def compute_tiled_map_range(nmap: nodes.Map, raise NotImplementedError("Cannot tile map with step") # check that we divide evenly - if sympy.Mod(exact_size, block_size).simplify() != 0: + if sp.Mod(exact_size, block_size).simplify() != 0: raise ValueError(f"{exact_size} is not divisible by {block_size}") td_to_new = exact_size // block_size - 1 td_step_new = td_step @@ -154,7 +157,11 @@ def propagate_rank_local_subsets( def rank_tile_map( - sdfg: SDFG, state: SDFGState, map_nodes: MapNodes, num_blocks: NumBlocks + sdfg: SDFG, + state: SDFGState, + map_nodes: MapNodes, + num_blocks: NumBlocks, + force_local_names: Optional[Dict[str, str]] = None, ) -> Tuple[RankVariables, GlobalToLocal, GlobalToLocal]: """ Tile the map according to the given block sizes, create rank-local views @@ -169,6 +176,7 @@ def rank_tile_map( :param state: The state to operate on. :param map_nodes: The map nodes to operate on. :param num_blocks: The number of blocks to tile the map with in each dimension. + :param force_local_names: A dictionary mapping global array names to local array names. :returns: created symbolic variables for each rank axis, and two lists of tuples associating global AccessNodes, local AccessNodes and the symbolic communication constraints. The first is for reads, the @@ -217,15 +225,18 @@ def rank_tile_map( elif isinstance(global_node.desc(sdfg), data.View): raise NotImplementedError("Cannot handle views yet") - # Create the rank-local array - local_name, _ = sdfg.add_transient( - name="local_" + global_name, - shape=[ - symbolic.overapproximate(r).simplify() - for r in new_subset.size_exact() - ], - dtype=sdfg.arrays[global_name].dtype, - find_new_name=True) + if force_local_names is not None and global_name in force_local_names: + # reuse descriptor + local_name = force_local_names[global_name] + else: + local_name, _ = sdfg.add_transient( + name="local_" + global_name, + shape=[ + symbolic.overapproximate(r).simplify() + for r in new_subset.size_exact() + ], + dtype=sdfg.arrays[global_name].dtype, + find_new_name=True) local_node = state.add_access(local_name) if is_input: @@ -247,6 +258,238 @@ def rank_tile_map( return rank_variables, result_read, result_write +def ordered_nodes_by_state( + sdfg: SDFG +) -> Dict[SDFGState, List[Union[nodes.Map, nodes.NestedSDFG]]]: + ordered_maps = {} + + for state in sdfg.nodes(): + top_level_nodes = set(state.scope_children()[None]) + result_nodes = [ + node.map if isinstance(node, nodes.MapEntry) else node + for node in sdfg_utils.dfs_topological_sort(state) + if isinstance(node, (nodes.MapEntry, + nodes.NestedSDFG)) and node in top_level_nodes + ] + ordered_maps[state] = result_nodes + return ordered_maps + + +def rank_tile_nested( + sdfg: SDFG, + state: SDFGState, + nnode: nodes.NestedSDFG, + schedule: DistributedSchedule, +) -> Tuple[NumBlocks, RankVariables, GlobalToLocal, GlobalToLocal]: + """ + Rank tile a nested SDFG. To make this viable in the cases we want to support + (typically just initialization for reductions) there are the following constraints: + + * Only a single ProcessGrid is used for the whole SDFG. + * No communication is added within the NSDFG (this enables fusability outside of it) + * As a result, schedules in the NSDFG must be 'consistent', in that if they + write to the same array, the communication constraints must be the same + """ + + nsdfg = nnode.sdfg + + # rule out unsqueezes + for edge in itertools.chain(state.in_edges(nnode), state.out_edges(nnode)): + if not utils.all_equal(edge.data.subset.size_exact(), + sdfg.arrays[edge.data.data].shape): + raise NotImplementedError( + "Cannot handle unsqueezes on NestedSDFG connectors") + + ordered_maps = ordered_nodes_by_state(nsdfg) + + rank_variables: Dict[str, int] = {} + num_fully_replicated: int = 0 + + constraints: Dict[str, subsets.Range] = {} + global_to_local: Dict[str, str] = {} + + # Rank tile each map, collecting the constraints by array + for nstate, map_nodes in ordered_maps.items(): + for map_node in map_nodes: + num_blocks = schedule[map_node] + + if len(num_blocks) != map_node.get_param_num(): + raise ValueError( + f"Schedule for {map_node} has {len(num_blocks)} " + f"block sizes, but {map_node.get_param_num()} are " + "required.") + + map_nodes = find_map_nodes(nstate, map_node) + + # modify the map to tile it + map_rank_variables, reads, writes = rank_tile_map( + nsdfg, + nstate, + map_nodes, + num_blocks, + force_local_names=global_to_local) + + # gather our current constraints, and delete the global access nodes + current_constraints = {} + for nglobal, nlocal, subset in reads: + current_constraints[nglobal.data] = subset + # update global mapping + global_to_local[nglobal.data] = nlocal.data + # delete global access node + nstate.remove_node(nglobal) + + for nglobal, nlocal, subset in writes: + # update global mapping + global_to_local[nglobal.data] = nlocal.data + if nglobal.data in current_constraints and current_constraints[ + nglobal.data] != subset: + raise ValueError( + "Inconsistent constraints found within map") + current_constraints[nglobal.data] = subset + nstate.remove_node(nglobal) + + # the tiled map now has its own rank variables. We now need to + # these variables consistent with the outer variables since the + # whole nested SDFG only has one ProcessGrid + me, mx = map_nodes + + renaming = {} + + map_block_size_per_symbol = { + s.name: bs + for s, bs in zip(map_rank_variables, num_blocks) + } + # the arrays that already have constraints from another map + for name, subset in current_constraints.items(): + if name not in constraints: + constraints[name] = subset + continue + + previous_constraint = constraints[name] + + is_symbol = lambda x: x.is_symbol + # try to match the constraints + + used_symbols = set(subset.free_symbols) + used_symbols |= set(previous_constraint.free_symbols) + + def new_symbol_name(name): + name = utils.find_str_not_in_set(used_symbols, name) + used_symbols.add(name) + return name + + pattern = copy.deepcopy(subset) + wilds = { + s: sp.Wild(new_symbol_name(f"i{i}"), + properties=[ + is_symbol, lambda x: rank_variables[x.name] + == map_block_size_per_symbol[s] + ]) + for i, s in enumerate(subset.free_symbols) + } + pattern.replace(wilds) + + def range_to_basic_expr(self): + exprs = [] + for (rb, re, rs), ts in zip(self.ranges, self.tile_sizes): + exprs.append(rb) + exprs.append(re) + exprs.append(ts) + return sp.Basic(*exprs) + + match = range_to_basic_expr(previous_constraint).match( + range_to_basic_expr(pattern)) + if match is None: + raise ValueError( + "Inconsistent constraints found within map") + + renaming.update(match) + + map_num_replicated = 0 + # now check that the block sizes are consistent + # The should already be consistent due to the sympy matching above, + # but let's make sure anyway + for v, bs in zip(map_rank_variables, num_blocks): + if v in renaming: + v = renaming[v] + + if v.name == node.FULLY_REPLICATED_RANK: + map_num_replicated += 1 + continue + + if v.name in rank_variables: + if rank_variables[v.name] != bs: + raise ValueError( + "Inconsistent block sizes found within NestedSDFG") + else: + rank_variables[v.name] = bs + num_fully_replicated = max(num_fully_replicated, + map_num_replicated) + + # Concretize process grid ordering + ordered_rank_variables = [] + block_sizes = [] + for ov, bs in rank_variables.items(): + ordered_rank_variables.append(ov) + block_sizes.append(bs) + ordered_rank_variables += [node.FULLY_REPLICATED_RANK + ] * num_fully_replicated + block_sizes += [1] * num_fully_replicated + + # Swap out the global connectors for local view connectors on the outside of the NSDFG + nnode.in_connectors = { + global_to_local[k]: v + for k, v in nnode.in_connectors.items() + } + nnode.out_connectors = { + global_to_local[k]: v + for k, v in nnode.out_connectors.items() + } + + # create outer transients for the local views + to_iter = itertools.chain( + zip(state.in_edges(nnode), itertools.repeat(True)), + zip(state.out_edges(nnode), itertools.repeat(False))) + + reads = [] + writes = [] + for edge, is_read in to_iter: + global_name = edge.dst_conn if is_read else edge.src_conn + local_name_in_nsdfg = global_to_local[global_name] + inner_desc = nsdfg.arrays[local_name_in_nsdfg] + outer_local_name = sdfg.add_datadesc( + name=local_name_in_nsdfg, + datadesc=copy.deepcopy(inner_desc), + find_new_name=True) + access = state.add_access(outer_local_name) + + if is_read: + reads.append((edge.src, access, constraints[global_name])) + redirect_args = dict(new_dst_conn=local_name_in_nsdfg, + new_src=access) + else: + writes.append((edge.dst, access, constraints[global_name])) + redirect_args = dict(new_src_conn=local_name_in_nsdfg, + new_dst=access) + + new_edge = xfh.redirect_edge(state, + edge, + new_data=outer_local_name, + **redirect_args) + new_edge.data.subset = constraints[global_name] + + for global_name, local_name in global_to_local.items(): + if not nsdfg.arrays[global_name].transient: + nsdfg.arrays[local_name].transient = False + del nsdfg.arrays[global_name] + + # specialize process grid vars inside the nsdfg + sdfg.specialize({v: 0 for v in ordered_rank_variables}) + ordered_rank_variables = list( + map(symbolic.pystr_to_symbolic, ordered_rank_variables)) + return block_sizes, ordered_rank_variables, reads, writes + + def lower(sdfg: SDFG, schedule: DistributedSchedule): """ Attempt to lower the SDFG to a SPMD MPI SDFG to distribute computation. @@ -257,39 +500,43 @@ def lower(sdfg: SDFG, schedule: DistributedSchedule): :param schedule: The schedule to use. :note: Operates in-place. """ - missing = set(distr_utils.all_top_level_maps(sdfg)).difference( schedule.keys()) + if missing: raise ValueError( - f"Missing schedule for maps {', '.join(map(str, missing))}") + f"Missing schedule for maps: {', '.join(map(lambda x: x.label, missing))}" + ) # Order the schedule topologically for each state - ordered_maps: Dict[SDFGState, List[nodes.Map]] = {} - - for state in sdfg.nodes(): - top_level_nodes = set(state.scope_children()[None]) - map_entries = [ - node.map for node in sdfg_utils.dfs_topological_sort(state) - if isinstance(node, nodes.MapEntry) and node in top_level_nodes - ] - ordered_maps[state] = map_entries + ordered_nodes = ordered_nodes_by_state(sdfg) # each map has a main process grid # with the dimension given by the schedule - maps_to_pgrids: Dict[nodes.Map, Tuple[RankVariables, str]] = {} - - for state, map_nodes in ordered_maps.items(): - for map_node in map_nodes: - num_blocks = schedule[map_node] - - if len(num_blocks) != map_node.get_param_num(): - raise ValueError( - f"Schedule for {map_node} has {len(num_blocks)} " - f"block sizes, but {map_node.get_param_num()} are " - "required.") - - map_nodes = find_map_nodes(state, map_node) + process_grids: Dict[str, RankVariables] = {} + + for state, top_level_nodes in ordered_nodes.items(): + for top_level_node in top_level_nodes: + if isinstance(top_level_node, nodes.NestedSDFG): + num_blocks, rank_variables, reads, writes = rank_tile_nested( + sdfg, state, top_level_node, schedule) + else: + map_node: nodes.Map = top_level_node + num_blocks = schedule[map_node] + + if len(num_blocks) != map_node.get_param_num(): + raise ValueError( + f"Schedule for {map_node} has {len(num_blocks)} " + f"block sizes, but {map_node.get_param_num()} are " + "required.") + + map_nodes = find_map_nodes(state, map_node) + + # modify the map to tile it + rank_variables, reads, writes = rank_tile_map( + sdfg, state, map_nodes, num_blocks) + rank_variable_names: List[str] = list( + map(lambda s: s.name, rank_variables)) # Create a process grid that will be used for communication process_grid_name = sdfg.add_pgrid(shape=num_blocks) @@ -303,27 +550,16 @@ def lower(sdfg: SDFG, schedule: DistributedSchedule): f'bool {process_grid_name}_valid;', ]) - # modify the map to tile it - rank_variables, reads, writes = rank_tile_map( - sdfg, state, map_nodes, num_blocks) - rank_variable_names: List[str] = list( - map(lambda s: s.name, rank_variables)) - - maps_to_pgrids[map_node] = rank_variables, process_grid_name + process_grids[process_grid_name] = rank_variables to_iter = itertools.chain(zip(reads, itertools.repeat(True)), zip(writes, itertools.repeat(False))) for (nglobal, nlocal, subset), is_read in to_iter: if not is_read: - # we need to check if this array was written before. - # If it was written before, and the current write is - # write-conflicted, we need to communicate the previously - # written values - if analysis.is_previously_written(sdfg, state, nlocal, - nglobal.data): - raise NotImplementedError( - "In place updates not supported yet") + # FIXME We don't need to communicate this if it is a + # derived schedule and one of our siblings comes after us. + pass full_subset = subsets.Range.from_array(nglobal.desc(sdfg)) @@ -355,11 +591,13 @@ def lower(sdfg: SDFG, schedule: DistributedSchedule): state.add_edge(comm, None, dst, None, sdfg.make_array_memlet(dst.data)) + sdfg.validate() utils.expand_nodes( sdfg, predicate=lambda n: isinstance(n, node.DistributedMemlet)) + sdfg.validate() # Now that we are done lowering, we can instatiate the process grid # variables with zero, since each rank only sees its section of the array - for _, (variables, _) in maps_to_pgrids.items(): + for _, variables in process_grids.items(): repl_dict = {v.name: 0 for v in variables} sdfg.specialize(repl_dict) diff --git a/daceml/distributed/utils.py b/daceml/distributed/utils.py index 7608b657..e92560e0 100644 --- a/daceml/distributed/utils.py +++ b/daceml/distributed/utils.py @@ -1,8 +1,13 @@ -from typing import List, Iterator +from typing import List, Iterator, Tuple, Union, Dict + +import numpy as np +import pytest import dace from dace import SDFG, SDFGState, nodes, dtypes +from dace.sdfg import utils as sdfg_utils from dace.libraries import mpi +from daceml.util import utils def initialize_fields(state: SDFGState, fields: List[str]): @@ -24,11 +29,20 @@ def initialize_fields(state: SDFGState, fields: List[str]): dace.Memlet.from_array(dummy_name, scal)) -def all_top_level_maps(sdfg: SDFG) -> Iterator[nodes.Map]: +def all_top_level_maps( + sdfg: SDFG, + yield_parent=False +) -> Iterator[Union[nodes.Map, Tuple[nodes.Map, SDFGState]]]: for state in sdfg.nodes(): for node in state.scope_children()[None]: if isinstance(node, nodes.MapEntry): - yield node.map + if yield_parent: + yield node.map, state + else: + yield node.map + elif isinstance(node, nodes.NestedSDFG): + # also check top-level nested SDFGs + yield from all_top_level_maps(node.sdfg, yield_parent) def add_debug_rank_cords_tasklet(sdfg: SDFG): @@ -37,7 +51,10 @@ def add_debug_rank_cords_tasklet(sdfg: SDFG): """ new_state = sdfg.add_state_before(sdfg.start_state, 'debug_MPI') - code = "{\n" + code = """{ + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + """ for grid_name, desc in sdfg.process_grids.items(): code += """ @@ -49,7 +66,7 @@ def add_debug_rank_cords_tasklet(sdfg: SDFG): printf("\\n"); }} - printf("Hello from rank %d in grid {grid_name}, with coords: ", __state->{grid_name}_rank); + printf("Hello from global rank %d, rank %d in grid {grid_name}, with coords: ", rank, __state->{grid_name}_rank); for (int i = 0; i < {grid_dims}; i++) {{ printf("%d ", __state->{grid_name}_coords[i]); @@ -72,3 +89,83 @@ def add_debug_rank_cords_tasklet(sdfg: SDFG): wnode = new_state.add_write(dummy_name) new_state.add_edge(tasklet, '__out', wnode, None, dace.Memlet.from_array(dummy_name, scal)) + + +def add_debugprint_tasklet(sdfg: SDFG, state: SDFGState, + node: nodes.AccessNode): + """ + Insert a tasklet that just prints out the given data + """ + desc = node.desc(sdfg) + loops = "\n".join("for (int i{v} = 0; i{v} < {s}; i{v}++)".format(v=i, s=s) + for i, s in enumerate(desc.shape)) + code = """ + {loops} + {{ + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + printf("RANK %d: {name}[%d] = %d\\n", rank, {indices}, {name}[{indices}]); + }} + """.format(name=node.data, + loops=loops, + indices=",".join( + ["i{}".format(i) for i in range(len(desc.shape))])) + + # add a tasklet that writes to nothing + tasklet = nodes.Tasklet("print_" + node.data, {"__in"}, {"__out"}, code, + dtypes.Language.CPP) + + state.add_node(tasklet) + state.add_edge(node, None, tasklet, '__in', + sdfg.make_array_memlet(node.data)) + + # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. + dummy_name, scal = sdfg.add_scalar("dummy", + dace.int32, + transient=True, + find_new_name=True) + state.add_edge(tasklet, '__out', state.add_write(dummy_name), None, + dace.Memlet.from_array(dummy_name, scal)) + + +def arange_with_size(size): + return np.arange(utils.prod(size), dtype=np.int64).reshape(size).copy() + + +def find_map_containing(sdfg, name) -> nodes.Map: + cands = [] + for node, state in all_top_level_maps(sdfg, yield_parent=True): + if name in node.label: + cands.append(node) + if len(cands) == 1: + return cands[0] + else: + raise ValueError("Found {} candidates for map name {}".format( + len(cands), name)) + + +def compile_and_call(sdfg, inputs: Dict[str, np.ndarray], + expected_output: np.ndarray, num_required_ranks: int): + MPI = pytest.importorskip("mpi4py.MPI") + commworld = MPI.COMM_WORLD.Dup() + commworld.Barrier() + rank = commworld.Get_rank() + size = commworld.Get_size() + + if size < num_required_ranks: + raise ValueError( + "This test requires at least {} ranks".format(num_required_ranks)) + + func = sdfg_utils.distributed_compile(sdfg, commworld) + + if rank == 0: + result = func(**inputs) + np.testing.assert_allclose(result, expected_output) + else: + dummy_inputs = { + k: np.zeros_like(v, shape=(1, )) + for k, v in inputs.items() + } + func(**dummy_inputs) + commworld.Barrier() + commworld.Free() diff --git a/tests/distributed/test_elementwise.py b/tests/distributed/test_elementwise.py new file mode 100644 index 00000000..f9985b8e --- /dev/null +++ b/tests/distributed/test_elementwise.py @@ -0,0 +1,59 @@ +""" +These tests expect to be with 4 ranks +""" +import pytest + +import dace + +from daceml.util import utils +from daceml.distributed import schedule +from daceml.distributed.utils import find_map_containing, compile_and_call, arange_with_size + + +@pytest.mark.parametrize("sizes", [ + [2], + [4], +]) +def test_elementwise_1d(sizes): + @dace + def program(x: dace.int64[64]): + return x + 5 + + sdfg = program.to_sdfg() + + map_entry = find_map_containing(sdfg, "") + schedule.lower(sdfg, {map_entry: sizes}) + + X = arange_with_size([64]) + expected = X + 5 + compile_and_call(sdfg, {'x': X.copy()}, expected, utils.prod(sizes)) + + +@pytest.mark.parametrize( + "sizes", + [ + [2, 1, 1], # fully replicate broadcasted => use 2d broadcast grid + [2, 2, 1], # no broadcast grid, 1d scatter grid + [2, 2, 2], # no broadcast grid, 2d scatter grid + [1, 2, 1], + ]) +def test_bcast_simple(sizes): + @dace + def program(x: dace.int64[4, 8, 16], y: dace.int64[8, 16]): + return x + y + + sdfg = program.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + map_entry = find_map_containing(sdfg, "") + schedule.lower(sdfg, {map_entry: sizes}) + + X = arange_with_size([4, 8, 16]) + Y = arange_with_size([8, 16]) + expected = X + Y + + compile_and_call(sdfg, { + 'x': X.copy(), + 'y': Y.copy() + }, expected, utils.prod(sizes)) diff --git a/tests/distributed/test_lower_schedule.py b/tests/distributed/test_lower_schedule.py deleted file mode 100644 index 0f8f7501..00000000 --- a/tests/distributed/test_lower_schedule.py +++ /dev/null @@ -1,93 +0,0 @@ -""" -These tests expect to be with 4 ranks -""" -import pytest -import numpy as np - -import dace -from dace.sdfg import utils as sdfg_utils - -from daceml.util import utils -from daceml.distributed import schedule, utils as distr_utils - -MPI = pytest.importorskip("mpi4py.MPI") - - -def arange_with_size(size): - return np.arange(utils.prod(size), dtype=np.int64).reshape(size).copy() - - -@pytest.mark.parametrize("sizes", [ - [2], - [4], -]) -def test_elementwise_1d(sizes): - assert utils.prod(sizes) <= 4 - - @dace - def program(x: dace.int64[64]): - return x + 5 - - sdfg = program.to_sdfg() - - map_entry = [n for n in sdfg.node(0).scope_children().keys() if n][0] - schedule.lower(sdfg, {map_entry.map: sizes}) - - X = arange_with_size([64]) - expected = X + 5 - - commworld = MPI.COMM_WORLD - rank = commworld.Get_rank() - size = commworld.Get_size() - if size < utils.prod(sizes): - raise ValueError("This test requires at least {} ranks".format( - utils.prod(sizes))) - - func = sdfg_utils.distributed_compile(sdfg, commworld) - - if rank == 0: - result = func(x=X.copy()) - np.testing.assert_allclose(result, expected) - else: - func(x=np.zeros((1, ), dtype=np.int64)) - - -@pytest.mark.parametrize( - "sizes", - [ - [2, 1, 1], # fully replicate broadcasted => use 2d broadcast grid - [2, 2, 1], # no broadcast grid, 1d scatter grid - [2, 2, 2], # no broadcast grid, 2d scatter grid - [1, 2, 1], - ]) -def test_bcast_simple(sizes): - @dace - def program(x: dace.int64[4, 8, 16], y: dace.int64[8, 16]): - return x + y - - sdfg = program.to_sdfg() - sdfg.expand_library_nodes() - sdfg.simplify() - - map_entry = [n for n in sdfg.node(0).scope_children().keys() if n][0] - schedule.lower(sdfg, {map_entry.map: sizes}) - - X = arange_with_size([4, 8, 16]) - Y = arange_with_size([8, 16]) - expected = X + Y - - commworld = MPI.COMM_WORLD - rank = commworld.Get_rank() - size = commworld.Get_size() - if size < utils.prod(sizes): - raise ValueError("This test requires at least {} ranks".format( - utils.prod(sizes))) - - func = sdfg_utils.distributed_compile(sdfg, commworld) - - if rank == 0: - result = func(x=X.copy(), y=Y.copy()) - np.testing.assert_allclose(result, expected) - else: - func(x=np.zeros((1, ), dtype=np.int64), - y=np.zeros((1, ), dtype=np.int64)) diff --git a/tests/distributed/test_nested.py b/tests/distributed/test_nested.py new file mode 100644 index 00000000..cb4c4bbc --- /dev/null +++ b/tests/distributed/test_nested.py @@ -0,0 +1,73 @@ +import pytest +import numpy as np + +import dace +from dace import nodes +from dace.transformation import interstate +from dace.sdfg import utils as sdfg_utils + +from daceml.util import utils +from daceml.distributed import schedule +from daceml.distributed.utils import find_map_containing, compile_and_call, arange_with_size + + +@pytest.mark.parametrize( + "sizes", + [ + [1, 1], + [2, 1], + [2, 2], # parallelize along the reduction axis with MPI reduce + [2, 4], # parallelize along the reduction axis with MPI reduce + ]) +def test_reduce_simple(sizes): + @dace + def program(x: dace.int64[16, 16]): + return np.add.reduce(x, axis=1) + + sdfg = program.to_sdfg() + sdfg.expand_library_nodes() + # don't expand the NSDFG, this test tests that + assert any( + isinstance(n, nodes.NestedSDFG) for state in sdfg.nodes() + for n in state.nodes()) + + reduce = find_map_containing(sdfg, "reduce_output") + init = find_map_containing(sdfg, "init") + + schedule.lower(sdfg, {reduce: sizes, init: sizes[:1]}) + + X = arange_with_size([16, 16]) + expected = X.copy().sum(axis=1) + + compile_and_call(sdfg, {'x': X.copy()}, expected, utils.prod(sizes)) + + +def test_nested_two_maps(): + @dace + def nested(x): + y = np.zeros_like(x, shape=(16, 16)) + for i, j, k in dace.map[0:16, 0:16, 0:32]: + with dace.tasklet: + inp << x[i, k, j] + out = inp + out >> y(1, lambda x, y: x + y)[i, j] + return y + + @dace + def program(x: dace.int64[16, 32, 16]): + return nested(x) + + sdfg = program.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(interstate.StateFusion) + sdfg.sdfg_list[1].apply_transformations_repeated(interstate.InlineSDFG) + sdfg.expand_library_nodes() + # don't expand the NSDFG, this test tests that + assert any( + isinstance(n, nodes.NestedSDFG) for state in sdfg.nodes() + for n in state.nodes()) + + elementwise = find_map_containing(sdfg, "test_nested_nested") + init = find_map_containing(sdfg, "full__map") + + schedule.lower(sdfg, {elementwise: [2, 2, 2], init: [2, 2]}) + sdfg.validate()