From e0bc48b92235029b89c3cf5a453e427defb79c86 Mon Sep 17 00:00:00 2001
From: Oliver Rausch
Date: Fri, 19 Aug 2022 18:08:49 +0200
Subject: [PATCH] Support nested SDFGs in distributed lowering

This allows us to support reductions with their initialization states.

The idea is that nested SDFGs are required to be scheduled such that
there is no communication within them. The user passes the schedules
for each map, and the implied communication constraints are then
checked for consistency.

Keeping communication out of the nested SDFGs means that there is no
communication between steps like reduction buffer initialization and
the reduction itself, and also means that all global communication is
kept top-level, where it is easier to optimize.

Pull Request: https://github.com/spcl/daceml/pull/123
---
 .github/workflows/cpu-ci.yml                  |   4 +-
 .github/workflows/gpu-ci.yml                  |   2 +-
 daceml/distributed/communication/subarrays.py |  25 +-
 daceml/distributed/schedule.py                | 344 +++++++++++++++---
 daceml/distributed/utils.py                   | 107 +++++-
 tests/distributed/test_elementwise.py         |  59 +++
 tests/distributed/test_lower_schedule.py      |  93 -----
 tests/distributed/test_nested.py              |  73 ++++
 8 files changed, 547 insertions(+), 160 deletions(-)
 create mode 100644 tests/distributed/test_elementwise.py
 delete mode 100644 tests/distributed/test_lower_schedule.py
 create mode 100644 tests/distributed/test_nested.py

diff --git a/.github/workflows/cpu-ci.yml b/.github/workflows/cpu-ci.yml
index 52e81a97..8223308f 100644
--- a/.github/workflows/cpu-ci.yml
+++ b/.github/workflows/cpu-ci.yml
@@ -58,7 +58,7 @@ jobs:
       - name: Test with pytest
         env:
           ORT_RELEASE: ${{ github.workspace }}/onnxruntime-daceml-patched
-          PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx" --timeout=500
+          PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx and not mpi" --timeout=500
         run: make test
 
       - name: Test with doctest
@@ -95,7 +95,7 @@ jobs:
 
       - name: Test with pytest
         env:
-          PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx" --timeout=500 --skip-ort
+          PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx and not mpi" --timeout=500 --skip-ort
         run: make test
 
       - name: Upload coverage
diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml
index 93282b78..f836a771 100644
--- a/.github/workflows/gpu-ci.yml
+++ b/.github/workflows/gpu-ci.yml
@@ -29,7 +29,7 @@ jobs:
 
       - name: Test with pytest
         env:
-          PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc --gpu-only -m "not slow and not fpga and not xilinx and not onnx" --timeout=500
+          PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc --gpu-only -m "not slow and not fpga and not xilinx and not onnx and not mpi" --timeout=500
         run: make test
 
       - name: Upload coverage
diff --git a/daceml/distributed/communication/subarrays.py b/daceml/distributed/communication/subarrays.py
index 633a499f..1948bcde 100644
--- a/daceml/distributed/communication/subarrays.py
+++ b/daceml/distributed/communication/subarrays.py
@@ -137,6 +137,7 @@ def compute_scatter_color(parent_grid_variables: List[symbolic.symbol],
 def try_construct_subarray(
         sdfg: SDFG, state: SDFGState, pgrid_name: str, global_desc: data.Data,
         subset: subsets.Range, grid_variables:
List[symbolic.symbol], + scatter: bool, dry_run: bool) -> Optional[Tuple[str, str, Optional[str]]]: """ Try to convert the given end of the distributed memlet to a subarray, @@ -151,6 +152,7 @@ def try_construct_subarray( :param subset: The end of the distributed memlet to convert. :param grid_variables: The process grid corresponding to the computation to which this is either an input or an output. + :param scatter: True if this is for a scatter. :param dry_run: If True, don't actually create the grids and subarray. Instead, return None. :return: The name of the subarray, the name of the scatter grid and the @@ -182,13 +184,15 @@ def try_construct_subarray( bcast_shape = subgrid_shape(bcast_color) if not dry_run: - scatter_grid_name = sdfg.add_pgrid(shape=scatter_shape, - parent_grid=pgrid_name, - color=scatter_color) + scatter_grid_name = sdfg.add_pgrid( + shape=scatter_shape, + parent_grid=pgrid_name, + color=scatter_color, + exact_grid=None if scatter else 0) + bcast_grid_name = sdfg.add_pgrid(shape=bcast_shape, parent_grid=pgrid_name, color=bcast_color) - for name, shape in ((scatter_grid_name, scatter_shape), (bcast_grid_name, bcast_shape)): distr_utils.initialize_fields(state, [ @@ -280,6 +284,7 @@ def can_be_applied(self, state: SDFGState, *_, **__): garr, node.src_subset, src_vars, + False, dry_run=True) if node.dst_pgrid is not None: @@ -290,6 +295,7 @@ def can_be_applied(self, state: SDFGState, *_, **__): garr, node.dst_subset, dst_vars, + False, dry_run=True) except CommunicationSolverException: return False @@ -329,7 +335,14 @@ def expansion(node: 'DistributedMemlet', state: SDFGState, sdfg: SDFG): rvars = src_vars subarray_name, scatter_grid, bcast_grid = try_construct_subarray( - sdfg, state, pgrid_name, garr, subset, rvars, dry_run=False) + sdfg, + state, + pgrid_name, + garr, + subset, + rvars, + scatter, + dry_run=False) if scatter: expansion = mpi.BlockScatter(node.label, @@ -339,7 +352,7 @@ def expansion(node: 'DistributedMemlet', state: SDFGState, sdfg: SDFG): else: expansion = mpi.BlockGather(node.label, subarray_type=subarray_name, - gather_grid=pgrid_name, + gather_grid=scatter_grid, reduce_grid=bcast_grid) # clean up connectors to match the new node diff --git a/daceml/distributed/schedule.py b/daceml/distributed/schedule.py index a4d72e1a..e39b00c4 100644 --- a/daceml/distributed/schedule.py +++ b/daceml/distributed/schedule.py @@ -2,11 +2,14 @@ A distributed schedule for a subgraph is a mapping from each map entry node to process grid dimensions. 
""" +import copy import collections -from typing import List, Tuple, Dict, Set, Optional +from typing import List, Tuple, Dict, Set, Optional, Union import itertools -import sympy +import networkx as nx + +import sympy as sp from dace import nodes, SDFG, SDFGState, symbolic, subsets, memlet, data from dace.sdfg import propagation, utils as sdfg_utils @@ -59,7 +62,7 @@ def compute_tiled_map_range(nmap: nodes.Map, raise NotImplementedError("Cannot tile map with step") # check that we divide evenly - if sympy.Mod(exact_size, block_size).simplify() != 0: + if sp.Mod(exact_size, block_size).simplify() != 0: raise ValueError(f"{exact_size} is not divisible by {block_size}") td_to_new = exact_size // block_size - 1 td_step_new = td_step @@ -154,7 +157,11 @@ def propagate_rank_local_subsets( def rank_tile_map( - sdfg: SDFG, state: SDFGState, map_nodes: MapNodes, num_blocks: NumBlocks + sdfg: SDFG, + state: SDFGState, + map_nodes: MapNodes, + num_blocks: NumBlocks, + force_local_names: Optional[Dict[str, str]] = None, ) -> Tuple[RankVariables, GlobalToLocal, GlobalToLocal]: """ Tile the map according to the given block sizes, create rank-local views @@ -169,6 +176,7 @@ def rank_tile_map( :param state: The state to operate on. :param map_nodes: The map nodes to operate on. :param num_blocks: The number of blocks to tile the map with in each dimension. + :param force_local_names: A dictionary mapping global array names to local array names. :returns: created symbolic variables for each rank axis, and two lists of tuples associating global AccessNodes, local AccessNodes and the symbolic communication constraints. The first is for reads, the @@ -217,15 +225,18 @@ def rank_tile_map( elif isinstance(global_node.desc(sdfg), data.View): raise NotImplementedError("Cannot handle views yet") - # Create the rank-local array - local_name, _ = sdfg.add_transient( - name="local_" + global_name, - shape=[ - symbolic.overapproximate(r).simplify() - for r in new_subset.size_exact() - ], - dtype=sdfg.arrays[global_name].dtype, - find_new_name=True) + if force_local_names is not None and global_name in force_local_names: + # reuse descriptor + local_name = force_local_names[global_name] + else: + local_name, _ = sdfg.add_transient( + name="local_" + global_name, + shape=[ + symbolic.overapproximate(r).simplify() + for r in new_subset.size_exact() + ], + dtype=sdfg.arrays[global_name].dtype, + find_new_name=True) local_node = state.add_access(local_name) if is_input: @@ -247,6 +258,238 @@ def rank_tile_map( return rank_variables, result_read, result_write +def ordered_nodes_by_state( + sdfg: SDFG +) -> Dict[SDFGState, List[Union[nodes.Map, nodes.NestedSDFG]]]: + ordered_maps = {} + + for state in sdfg.nodes(): + top_level_nodes = set(state.scope_children()[None]) + result_nodes = [ + node.map if isinstance(node, nodes.MapEntry) else node + for node in sdfg_utils.dfs_topological_sort(state) + if isinstance(node, (nodes.MapEntry, + nodes.NestedSDFG)) and node in top_level_nodes + ] + ordered_maps[state] = result_nodes + return ordered_maps + + +def rank_tile_nested( + sdfg: SDFG, + state: SDFGState, + nnode: nodes.NestedSDFG, + schedule: DistributedSchedule, +) -> Tuple[NumBlocks, RankVariables, GlobalToLocal, GlobalToLocal]: + """ + Rank tile a nested SDFG. To make this viable in the cases we want to support + (typically just initialization for reductions) there are the following constraints: + + * Only a single ProcessGrid is used for the whole SDFG. 
+ * No communication is added within the NSDFG (this enables fusability outside of it) + * As a result, schedules in the NSDFG must be 'consistent', in that if they + write to the same array, the communication constraints must be the same + """ + + nsdfg = nnode.sdfg + + # rule out unsqueezes + for edge in itertools.chain(state.in_edges(nnode), state.out_edges(nnode)): + if not utils.all_equal(edge.data.subset.size_exact(), + sdfg.arrays[edge.data.data].shape): + raise NotImplementedError( + "Cannot handle unsqueezes on NestedSDFG connectors") + + ordered_maps = ordered_nodes_by_state(nsdfg) + + rank_variables: Dict[str, int] = {} + num_fully_replicated: int = 0 + + constraints: Dict[str, subsets.Range] = {} + global_to_local: Dict[str, str] = {} + + # Rank tile each map, collecting the constraints by array + for nstate, map_nodes in ordered_maps.items(): + for map_node in map_nodes: + num_blocks = schedule[map_node] + + if len(num_blocks) != map_node.get_param_num(): + raise ValueError( + f"Schedule for {map_node} has {len(num_blocks)} " + f"block sizes, but {map_node.get_param_num()} are " + "required.") + + map_nodes = find_map_nodes(nstate, map_node) + + # modify the map to tile it + map_rank_variables, reads, writes = rank_tile_map( + nsdfg, + nstate, + map_nodes, + num_blocks, + force_local_names=global_to_local) + + # gather our current constraints, and delete the global access nodes + current_constraints = {} + for nglobal, nlocal, subset in reads: + current_constraints[nglobal.data] = subset + # update global mapping + global_to_local[nglobal.data] = nlocal.data + # delete global access node + nstate.remove_node(nglobal) + + for nglobal, nlocal, subset in writes: + # update global mapping + global_to_local[nglobal.data] = nlocal.data + if nglobal.data in current_constraints and current_constraints[ + nglobal.data] != subset: + raise ValueError( + "Inconsistent constraints found within map") + current_constraints[nglobal.data] = subset + nstate.remove_node(nglobal) + + # the tiled map now has its own rank variables. 
We now need to + # these variables consistent with the outer variables since the + # whole nested SDFG only has one ProcessGrid + me, mx = map_nodes + + renaming = {} + + map_block_size_per_symbol = { + s.name: bs + for s, bs in zip(map_rank_variables, num_blocks) + } + # the arrays that already have constraints from another map + for name, subset in current_constraints.items(): + if name not in constraints: + constraints[name] = subset + continue + + previous_constraint = constraints[name] + + is_symbol = lambda x: x.is_symbol + # try to match the constraints + + used_symbols = set(subset.free_symbols) + used_symbols |= set(previous_constraint.free_symbols) + + def new_symbol_name(name): + name = utils.find_str_not_in_set(used_symbols, name) + used_symbols.add(name) + return name + + pattern = copy.deepcopy(subset) + wilds = { + s: sp.Wild(new_symbol_name(f"i{i}"), + properties=[ + is_symbol, lambda x: rank_variables[x.name] + == map_block_size_per_symbol[s] + ]) + for i, s in enumerate(subset.free_symbols) + } + pattern.replace(wilds) + + def range_to_basic_expr(self): + exprs = [] + for (rb, re, rs), ts in zip(self.ranges, self.tile_sizes): + exprs.append(rb) + exprs.append(re) + exprs.append(ts) + return sp.Basic(*exprs) + + match = range_to_basic_expr(previous_constraint).match( + range_to_basic_expr(pattern)) + if match is None: + raise ValueError( + "Inconsistent constraints found within map") + + renaming.update(match) + + map_num_replicated = 0 + # now check that the block sizes are consistent + # The should already be consistent due to the sympy matching above, + # but let's make sure anyway + for v, bs in zip(map_rank_variables, num_blocks): + if v in renaming: + v = renaming[v] + + if v.name == node.FULLY_REPLICATED_RANK: + map_num_replicated += 1 + continue + + if v.name in rank_variables: + if rank_variables[v.name] != bs: + raise ValueError( + "Inconsistent block sizes found within NestedSDFG") + else: + rank_variables[v.name] = bs + num_fully_replicated = max(num_fully_replicated, + map_num_replicated) + + # Concretize process grid ordering + ordered_rank_variables = [] + block_sizes = [] + for ov, bs in rank_variables.items(): + ordered_rank_variables.append(ov) + block_sizes.append(bs) + ordered_rank_variables += [node.FULLY_REPLICATED_RANK + ] * num_fully_replicated + block_sizes += [1] * num_fully_replicated + + # Swap out the global connectors for local view connectors on the outside of the NSDFG + nnode.in_connectors = { + global_to_local[k]: v + for k, v in nnode.in_connectors.items() + } + nnode.out_connectors = { + global_to_local[k]: v + for k, v in nnode.out_connectors.items() + } + + # create outer transients for the local views + to_iter = itertools.chain( + zip(state.in_edges(nnode), itertools.repeat(True)), + zip(state.out_edges(nnode), itertools.repeat(False))) + + reads = [] + writes = [] + for edge, is_read in to_iter: + global_name = edge.dst_conn if is_read else edge.src_conn + local_name_in_nsdfg = global_to_local[global_name] + inner_desc = nsdfg.arrays[local_name_in_nsdfg] + outer_local_name = sdfg.add_datadesc( + name=local_name_in_nsdfg, + datadesc=copy.deepcopy(inner_desc), + find_new_name=True) + access = state.add_access(outer_local_name) + + if is_read: + reads.append((edge.src, access, constraints[global_name])) + redirect_args = dict(new_dst_conn=local_name_in_nsdfg, + new_src=access) + else: + writes.append((edge.dst, access, constraints[global_name])) + redirect_args = dict(new_src_conn=local_name_in_nsdfg, + new_dst=access) + + new_edge = 
xfh.redirect_edge(state, + edge, + new_data=outer_local_name, + **redirect_args) + new_edge.data.subset = constraints[global_name] + + for global_name, local_name in global_to_local.items(): + if not nsdfg.arrays[global_name].transient: + nsdfg.arrays[local_name].transient = False + del nsdfg.arrays[global_name] + + # specialize process grid vars inside the nsdfg + sdfg.specialize({v: 0 for v in ordered_rank_variables}) + ordered_rank_variables = list( + map(symbolic.pystr_to_symbolic, ordered_rank_variables)) + return block_sizes, ordered_rank_variables, reads, writes + + def lower(sdfg: SDFG, schedule: DistributedSchedule): """ Attempt to lower the SDFG to a SPMD MPI SDFG to distribute computation. @@ -257,39 +500,43 @@ def lower(sdfg: SDFG, schedule: DistributedSchedule): :param schedule: The schedule to use. :note: Operates in-place. """ - missing = set(distr_utils.all_top_level_maps(sdfg)).difference( schedule.keys()) + if missing: raise ValueError( - f"Missing schedule for maps {', '.join(map(str, missing))}") + f"Missing schedule for maps: {', '.join(map(lambda x: x.label, missing))}" + ) # Order the schedule topologically for each state - ordered_maps: Dict[SDFGState, List[nodes.Map]] = {} - - for state in sdfg.nodes(): - top_level_nodes = set(state.scope_children()[None]) - map_entries = [ - node.map for node in sdfg_utils.dfs_topological_sort(state) - if isinstance(node, nodes.MapEntry) and node in top_level_nodes - ] - ordered_maps[state] = map_entries + ordered_nodes = ordered_nodes_by_state(sdfg) # each map has a main process grid # with the dimension given by the schedule - maps_to_pgrids: Dict[nodes.Map, Tuple[RankVariables, str]] = {} - - for state, map_nodes in ordered_maps.items(): - for map_node in map_nodes: - num_blocks = schedule[map_node] - - if len(num_blocks) != map_node.get_param_num(): - raise ValueError( - f"Schedule for {map_node} has {len(num_blocks)} " - f"block sizes, but {map_node.get_param_num()} are " - "required.") - - map_nodes = find_map_nodes(state, map_node) + process_grids: Dict[str, RankVariables] = {} + + for state, top_level_nodes in ordered_nodes.items(): + for top_level_node in top_level_nodes: + if isinstance(top_level_node, nodes.NestedSDFG): + num_blocks, rank_variables, reads, writes = rank_tile_nested( + sdfg, state, top_level_node, schedule) + else: + map_node: nodes.Map = top_level_node + num_blocks = schedule[map_node] + + if len(num_blocks) != map_node.get_param_num(): + raise ValueError( + f"Schedule for {map_node} has {len(num_blocks)} " + f"block sizes, but {map_node.get_param_num()} are " + "required.") + + map_nodes = find_map_nodes(state, map_node) + + # modify the map to tile it + rank_variables, reads, writes = rank_tile_map( + sdfg, state, map_nodes, num_blocks) + rank_variable_names: List[str] = list( + map(lambda s: s.name, rank_variables)) # Create a process grid that will be used for communication process_grid_name = sdfg.add_pgrid(shape=num_blocks) @@ -303,27 +550,16 @@ def lower(sdfg: SDFG, schedule: DistributedSchedule): f'bool {process_grid_name}_valid;', ]) - # modify the map to tile it - rank_variables, reads, writes = rank_tile_map( - sdfg, state, map_nodes, num_blocks) - rank_variable_names: List[str] = list( - map(lambda s: s.name, rank_variables)) - - maps_to_pgrids[map_node] = rank_variables, process_grid_name + process_grids[process_grid_name] = rank_variables to_iter = itertools.chain(zip(reads, itertools.repeat(True)), zip(writes, itertools.repeat(False))) for (nglobal, nlocal, subset), is_read in to_iter: if 
not is_read: - # we need to check if this array was written before. - # If it was written before, and the current write is - # write-conflicted, we need to communicate the previously - # written values - if analysis.is_previously_written(sdfg, state, nlocal, - nglobal.data): - raise NotImplementedError( - "In place updates not supported yet") + # FIXME We don't need to communicate this if it is a + # derived schedule and one of our siblings comes after us. + pass full_subset = subsets.Range.from_array(nglobal.desc(sdfg)) @@ -355,11 +591,13 @@ def lower(sdfg: SDFG, schedule: DistributedSchedule): state.add_edge(comm, None, dst, None, sdfg.make_array_memlet(dst.data)) + sdfg.validate() utils.expand_nodes( sdfg, predicate=lambda n: isinstance(n, node.DistributedMemlet)) + sdfg.validate() # Now that we are done lowering, we can instatiate the process grid # variables with zero, since each rank only sees its section of the array - for _, (variables, _) in maps_to_pgrids.items(): + for _, variables in process_grids.items(): repl_dict = {v.name: 0 for v in variables} sdfg.specialize(repl_dict) diff --git a/daceml/distributed/utils.py b/daceml/distributed/utils.py index 7608b657..e92560e0 100644 --- a/daceml/distributed/utils.py +++ b/daceml/distributed/utils.py @@ -1,8 +1,13 @@ -from typing import List, Iterator +from typing import List, Iterator, Tuple, Union, Dict + +import numpy as np +import pytest import dace from dace import SDFG, SDFGState, nodes, dtypes +from dace.sdfg import utils as sdfg_utils from dace.libraries import mpi +from daceml.util import utils def initialize_fields(state: SDFGState, fields: List[str]): @@ -24,11 +29,20 @@ def initialize_fields(state: SDFGState, fields: List[str]): dace.Memlet.from_array(dummy_name, scal)) -def all_top_level_maps(sdfg: SDFG) -> Iterator[nodes.Map]: +def all_top_level_maps( + sdfg: SDFG, + yield_parent=False +) -> Iterator[Union[nodes.Map, Tuple[nodes.Map, SDFGState]]]: for state in sdfg.nodes(): for node in state.scope_children()[None]: if isinstance(node, nodes.MapEntry): - yield node.map + if yield_parent: + yield node.map, state + else: + yield node.map + elif isinstance(node, nodes.NestedSDFG): + # also check top-level nested SDFGs + yield from all_top_level_maps(node.sdfg, yield_parent) def add_debug_rank_cords_tasklet(sdfg: SDFG): @@ -37,7 +51,10 @@ def add_debug_rank_cords_tasklet(sdfg: SDFG): """ new_state = sdfg.add_state_before(sdfg.start_state, 'debug_MPI') - code = "{\n" + code = """{ + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + """ for grid_name, desc in sdfg.process_grids.items(): code += """ @@ -49,7 +66,7 @@ def add_debug_rank_cords_tasklet(sdfg: SDFG): printf("\\n"); }} - printf("Hello from rank %d in grid {grid_name}, with coords: ", __state->{grid_name}_rank); + printf("Hello from global rank %d, rank %d in grid {grid_name}, with coords: ", rank, __state->{grid_name}_rank); for (int i = 0; i < {grid_dims}; i++) {{ printf("%d ", __state->{grid_name}_coords[i]); @@ -72,3 +89,83 @@ def add_debug_rank_cords_tasklet(sdfg: SDFG): wnode = new_state.add_write(dummy_name) new_state.add_edge(tasklet, '__out', wnode, None, dace.Memlet.from_array(dummy_name, scal)) + + +def add_debugprint_tasklet(sdfg: SDFG, state: SDFGState, + node: nodes.AccessNode): + """ + Insert a tasklet that just prints out the given data + """ + desc = node.desc(sdfg) + loops = "\n".join("for (int i{v} = 0; i{v} < {s}; i{v}++)".format(v=i, s=s) + for i, s in enumerate(desc.shape)) + code = """ + {loops} + {{ + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, 
&rank); + printf("RANK %d: {name}[%d] = %d\\n", rank, {indices}, {name}[{indices}]); + }} + """.format(name=node.data, + loops=loops, + indices=",".join( + ["i{}".format(i) for i in range(len(desc.shape))])) + + # add a tasklet that writes to nothing + tasklet = nodes.Tasklet("print_" + node.data, {"__in"}, {"__out"}, code, + dtypes.Language.CPP) + + state.add_node(tasklet) + state.add_edge(node, None, tasklet, '__in', + sdfg.make_array_memlet(node.data)) + + # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. + dummy_name, scal = sdfg.add_scalar("dummy", + dace.int32, + transient=True, + find_new_name=True) + state.add_edge(tasklet, '__out', state.add_write(dummy_name), None, + dace.Memlet.from_array(dummy_name, scal)) + + +def arange_with_size(size): + return np.arange(utils.prod(size), dtype=np.int64).reshape(size).copy() + + +def find_map_containing(sdfg, name) -> nodes.Map: + cands = [] + for node, state in all_top_level_maps(sdfg, yield_parent=True): + if name in node.label: + cands.append(node) + if len(cands) == 1: + return cands[0] + else: + raise ValueError("Found {} candidates for map name {}".format( + len(cands), name)) + + +def compile_and_call(sdfg, inputs: Dict[str, np.ndarray], + expected_output: np.ndarray, num_required_ranks: int): + MPI = pytest.importorskip("mpi4py.MPI") + commworld = MPI.COMM_WORLD.Dup() + commworld.Barrier() + rank = commworld.Get_rank() + size = commworld.Get_size() + + if size < num_required_ranks: + raise ValueError( + "This test requires at least {} ranks".format(num_required_ranks)) + + func = sdfg_utils.distributed_compile(sdfg, commworld) + + if rank == 0: + result = func(**inputs) + np.testing.assert_allclose(result, expected_output) + else: + dummy_inputs = { + k: np.zeros_like(v, shape=(1, )) + for k, v in inputs.items() + } + func(**dummy_inputs) + commworld.Barrier() + commworld.Free() diff --git a/tests/distributed/test_elementwise.py b/tests/distributed/test_elementwise.py new file mode 100644 index 00000000..f9985b8e --- /dev/null +++ b/tests/distributed/test_elementwise.py @@ -0,0 +1,59 @@ +""" +These tests expect to be with 4 ranks +""" +import pytest + +import dace + +from daceml.util import utils +from daceml.distributed import schedule +from daceml.distributed.utils import find_map_containing, compile_and_call, arange_with_size + + +@pytest.mark.parametrize("sizes", [ + [2], + [4], +]) +def test_elementwise_1d(sizes): + @dace + def program(x: dace.int64[64]): + return x + 5 + + sdfg = program.to_sdfg() + + map_entry = find_map_containing(sdfg, "") + schedule.lower(sdfg, {map_entry: sizes}) + + X = arange_with_size([64]) + expected = X + 5 + compile_and_call(sdfg, {'x': X.copy()}, expected, utils.prod(sizes)) + + +@pytest.mark.parametrize( + "sizes", + [ + [2, 1, 1], # fully replicate broadcasted => use 2d broadcast grid + [2, 2, 1], # no broadcast grid, 1d scatter grid + [2, 2, 2], # no broadcast grid, 2d scatter grid + [1, 2, 1], + ]) +def test_bcast_simple(sizes): + @dace + def program(x: dace.int64[4, 8, 16], y: dace.int64[8, 16]): + return x + y + + sdfg = program.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + map_entry = find_map_containing(sdfg, "") + schedule.lower(sdfg, {map_entry: sizes}) + + X = arange_with_size([4, 8, 16]) + Y = arange_with_size([8, 16]) + expected = X + Y + + compile_and_call(sdfg, { + 'x': X.copy(), + 'y': Y.copy() + }, expected, utils.prod(sizes)) diff --git a/tests/distributed/test_lower_schedule.py b/tests/distributed/test_lower_schedule.py deleted 
file mode 100644
index 0f8f7501..00000000
--- a/tests/distributed/test_lower_schedule.py
+++ /dev/null
@@ -1,93 +0,0 @@
-"""
-These tests expect to be with 4 ranks
-"""
-import pytest
-import numpy as np
-
-import dace
-from dace.sdfg import utils as sdfg_utils
-
-from daceml.util import utils
-from daceml.distributed import schedule, utils as distr_utils
-
-MPI = pytest.importorskip("mpi4py.MPI")
-
-
-def arange_with_size(size):
-    return np.arange(utils.prod(size), dtype=np.int64).reshape(size).copy()
-
-
-@pytest.mark.parametrize("sizes", [
-    [2],
-    [4],
-])
-def test_elementwise_1d(sizes):
-    assert utils.prod(sizes) <= 4
-
-    @dace
-    def program(x: dace.int64[64]):
-        return x + 5
-
-    sdfg = program.to_sdfg()
-
-    map_entry = [n for n in sdfg.node(0).scope_children().keys() if n][0]
-    schedule.lower(sdfg, {map_entry.map: sizes})
-
-    X = arange_with_size([64])
-    expected = X + 5
-
-    commworld = MPI.COMM_WORLD
-    rank = commworld.Get_rank()
-    size = commworld.Get_size()
-    if size < utils.prod(sizes):
-        raise ValueError("This test requires at least {} ranks".format(
-            utils.prod(sizes)))
-
-    func = sdfg_utils.distributed_compile(sdfg, commworld)
-
-    if rank == 0:
-        result = func(x=X.copy())
-        np.testing.assert_allclose(result, expected)
-    else:
-        func(x=np.zeros((1, ), dtype=np.int64))
-
-
-@pytest.mark.parametrize(
-    "sizes",
-    [
-        [2, 1, 1],  # fully replicate broadcasted => use 2d broadcast grid
-        [2, 2, 1],  # no broadcast grid, 1d scatter grid
-        [2, 2, 2],  # no broadcast grid, 2d scatter grid
-        [1, 2, 1],
-    ])
-def test_bcast_simple(sizes):
-    @dace
-    def program(x: dace.int64[4, 8, 16], y: dace.int64[8, 16]):
-        return x + y
-
-    sdfg = program.to_sdfg()
-    sdfg.expand_library_nodes()
-    sdfg.simplify()
-
-    map_entry = [n for n in sdfg.node(0).scope_children().keys() if n][0]
-    schedule.lower(sdfg, {map_entry.map: sizes})
-
-    X = arange_with_size([4, 8, 16])
-    Y = arange_with_size([8, 16])
-    expected = X + Y
-
-    commworld = MPI.COMM_WORLD
-    rank = commworld.Get_rank()
-    size = commworld.Get_size()
-    if size < utils.prod(sizes):
-        raise ValueError("This test requires at least {} ranks".format(
-            utils.prod(sizes)))
-
-    func = sdfg_utils.distributed_compile(sdfg, commworld)
-
-    if rank == 0:
-        result = func(x=X.copy(), y=Y.copy())
-        np.testing.assert_allclose(result, expected)
-    else:
-        func(x=np.zeros((1, ), dtype=np.int64),
-             y=np.zeros((1, ), dtype=np.int64))
diff --git a/tests/distributed/test_nested.py b/tests/distributed/test_nested.py
new file mode 100644
index 00000000..cb4c4bbc
--- /dev/null
+++ b/tests/distributed/test_nested.py
@@ -0,0 +1,73 @@
+import pytest
+import numpy as np
+
+import dace
+from dace import nodes
+from dace.transformation import interstate
+from dace.sdfg import utils as sdfg_utils
+
+from daceml.util import utils
+from daceml.distributed import schedule
+from daceml.distributed.utils import find_map_containing, compile_and_call, arange_with_size
+
+
+@pytest.mark.parametrize(
+    "sizes",
+    [
+        [1, 1],
+        [2, 1],
+        [2, 2],  # parallelize along the reduction axis with MPI reduce
+        [2, 4],  # parallelize along the reduction axis with MPI reduce
+    ])
+def test_reduce_simple(sizes):
+    @dace
+    def program(x: dace.int64[16, 16]):
+        return np.add.reduce(x, axis=1)
+
+    sdfg = program.to_sdfg()
+    sdfg.expand_library_nodes()
+    # don't expand the NSDFG, this test tests that
+    assert any(
+        isinstance(n, nodes.NestedSDFG) for state in sdfg.nodes()
+        for n in state.nodes())
+
+    reduce = find_map_containing(sdfg, "reduce_output")
+    init = find_map_containing(sdfg, "init")
+
+    schedule.lower(sdfg, {reduce: sizes, init: sizes[:1]})
+
+    X = arange_with_size([16, 16])
+    expected = X.copy().sum(axis=1)
+
+    compile_and_call(sdfg, {'x': X.copy()}, expected, utils.prod(sizes))
+
+
+def test_nested_two_maps():
+    @dace
+    def nested(x):
+        y = np.zeros_like(x, shape=(16, 16))
+        for i, j, k in dace.map[0:16, 0:16, 0:32]:
+            with dace.tasklet:
+                inp << x[i, k, j]
+                out = inp
+                out >> y(1, lambda x, y: x + y)[i, j]
+        return y
+
+    @dace
+    def program(x: dace.int64[16, 32, 16]):
+        return nested(x)
+
+    sdfg = program.to_sdfg(simplify=False)
+    sdfg.apply_transformations_repeated(interstate.StateFusion)
+    sdfg.sdfg_list[1].apply_transformations_repeated(interstate.InlineSDFG)
+    sdfg.expand_library_nodes()
+    # don't expand the NSDFG, this test tests that
+    assert any(
+        isinstance(n, nodes.NestedSDFG) for state in sdfg.nodes()
+        for n in state.nodes())
+
+    elementwise = find_map_containing(sdfg, "test_nested_nested")
+    init = find_map_containing(sdfg, "full__map")
+
+    schedule.lower(sdfg, {elementwise: [2, 2, 2], init: [2, 2]})
+    sdfg.validate()