diff --git a/.github/workflows/cpu-ci.yml b/.github/workflows/cpu-ci.yml index 52e81a97..8223308f 100644 --- a/.github/workflows/cpu-ci.yml +++ b/.github/workflows/cpu-ci.yml @@ -58,7 +58,7 @@ jobs: - name: Test with pytest env: ORT_RELEASE: ${{ github.workspace }}/onnxruntime-daceml-patched - PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx" --timeout=500 + PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx and not mpi" --timeout=500 run: make test - name: Test with doctest @@ -95,7 +95,7 @@ jobs: - name: Test with pytest env: - PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx" --timeout=500 --skip-ort + PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc -m "not fpga and not xilinx and not gpu and not onnx and not mpi" --timeout=500 --skip-ort run: make test - name: Upload coverage diff --git a/.github/workflows/dist-ci.yml b/.github/workflows/dist-ci.yml new file mode 100644 index 00000000..d4f5758e --- /dev/null +++ b/.github/workflows/dist-ci.yml @@ -0,0 +1,38 @@ +name: Distributed CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + test-distributed: + if: ${{ !contains(github.event.pull_request.labels.*.name, 'no-ci') }} + runs-on: [self-hosted, linux, gpu] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + submodules: 'recursive' + + - name: Install dependencies + env: + UPDATE_PIP: 'true' + run: | + rm -rf .dacecache tests/.dacecache + . 
/opt/setupenv + make clean install + venv/bin/pip install mpi4py + + - name: Run Distributed Tests + env: + PYTEST: venv/bin/coverage run --parallel-mode --source=daceml -m pytest + PYTEST_ARGS: -s + PYTEST_PLUGINS: tests.distributed.mpi_mute + MPI_PREFIX: mpirun -np 8 --oversubscribe + run: make test-distributed + + - name: Upload coverage + run: make codecov diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index 93282b78..f836a771 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -29,7 +29,7 @@ jobs: - name: Test with pytest env: - PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc --gpu-only -m "not slow and not fpga and not xilinx and not onnx" --timeout=500 + PYTEST_ARGS: --cov=daceml --cov-report=term --cov-report xml --cov-config=.coveragerc --gpu-only -m "not slow and not fpga and not xilinx and not onnx and not mpi" --timeout=500 run: make test - name: Upload coverage diff --git a/Makefile b/Makefile index 99f92cfd..7e84b30c 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,7 @@ VENV_PATH ?= venv PYTHON ?= python PYTHON_BINARY ?= python PYTEST ?= pytest +MPI_PREFIX ?= mpirun PIP ?= pip YAPF ?= yapf @@ -17,13 +18,6 @@ ACTIVATE = . $(VENV_PATH)/bin/activate && endif .PHONY: clean doc doctest test test-gpu codecov check-formatting check-formatting-names clean-dacecaches yapf -clean: - ! test -d $(VENV_PATH) || rm -r $(VENV_PATH) - -venv: -ifneq ($(VENV_PATH),) - test -d $(VENV_PATH) || echo "Creating new venv" && $(PYTHON) -m venv ./$(VENV_PATH) -endif install: venv ifneq ($(VENV_PATH),) @@ -35,6 +29,14 @@ endif $(ACTIVATE) $(PIP) install $(TORCH_VERSION) $(ACTIVATE) $(PIP) install -e .[testing,docs] +clean: + ! 
test -d $(VENV_PATH) || rm -r $(VENV_PATH) + +venv: +ifneq ($(VENV_PATH),) + test -d $(VENV_PATH) || echo "Creating new venv" && $(PYTHON) -m venv ./$(VENV_PATH) +endif + doc: # suppress warnings in ONNXOps docstrings using grep -v $(ACTIVATE) cd doc && make clean html 2>&1 \ @@ -62,6 +64,9 @@ test-intel-fpga: test-xilinx: $(ACTIVATE) $(PYTEST) $(PYTEST_ARGS) tests/torch/fpga/ +test-distributed: + $(ACTIVATE) $(MPI_PREFIX) $(PYTEST) $(PYTEST_ARGS) tests/distributed + codecov: curl -s https://codecov.io/bash | bash diff --git a/daceml/distributed/__init__.py b/daceml/distributed/__init__.py new file mode 100644 index 00000000..63d78206 --- /dev/null +++ b/daceml/distributed/__init__.py @@ -0,0 +1 @@ +from . import schedule diff --git a/daceml/distributed/communication/node.py b/daceml/distributed/communication/node.py new file mode 100644 index 00000000..ff04afd6 --- /dev/null +++ b/daceml/distributed/communication/node.py @@ -0,0 +1,125 @@ +import dace +import dace.library +from dace import nodes, properties, SDFG, SDFGState, subsets + +from . import subarrays + +from daceml.util import utils + +#: a placeholder variable that is used for all fully replicated ranks +FULLY_REPLICATED_RANK = 'FULLY_REPLICATED_RANK' + + +@dace.library.node +class DistributedMemlet(nodes.LibraryNode): + """ + Communication node that distributes input to output based on symbolic + expressions for the rank-local subsets. + + These expressions (``src_subset`` and ``dst_subset``) should include + symbolic variables given in ``src_rank_variables`` and + ``dst_rank_variables``. 
+ """ + + # Global properties + implementations = { + "subarrays": subarrays.CommunicateSubArrays, + } + default_implementation = "subarrays" + + src_rank_variables = properties.ListProperty( + element_type=str, + desc="List of variables used in the indexing expressions that represent" + " rank identifiers in the source expression") + src_pgrid = properties.Property(dtype=str, allow_none=True) + src_subset = properties.RangeProperty( + default=subsets.Range([]), + desc="Subset of the input array that is held on each rank") + src_global_array = properties.DataProperty() + + dst_rank_variables = properties.ListProperty( + element_type=str, + desc="List of variables used in the indexing expressions that represent" + " rank identifiers in the destination expression") + dst_pgrid = properties.Property(dtype=str, allow_none=True) + dst_subset = properties.RangeProperty( + default=subsets.Range([]), + desc="Subset of the output array that is held on each rank") + dst_global_array = properties.DataProperty() + + def __init__(self, name, src_rank_variables, src_pgrid, src_subset, + src_global_array, dst_rank_variables, dst_pgrid, dst_subset, + dst_global_array): + super().__init__(name) + self.src_rank_variables = src_rank_variables + self.src_pgrid = src_pgrid + self.src_subset = src_subset + self.src_global_array = src_global_array + self.dst_rank_variables = dst_rank_variables + self.dst_pgrid = dst_pgrid + self.dst_subset = dst_subset + self.dst_global_array = dst_global_array + + def validate(self, sdfg: SDFG, state: SDFGState): + + if self.src_global_array not in sdfg.arrays: + raise ValueError( + f"{self.src_global_array} is not an array in the SDFG") + if self.dst_global_array not in sdfg.arrays: + raise ValueError( + f"{self.dst_global_array} is not an array in the SDFG") + + src_free_vars = self.src_subset.free_symbols + dst_free_vars = self.dst_subset.free_symbols + + if self.src_pgrid is None and self.dst_pgrid is None: + raise ValueError("At least one process 
grid must be specified") + + if src_free_vars.difference(self.src_rank_variables): + raise ValueError( + "Source subset has free variables that are not rank variables") + if dst_free_vars.difference(self.dst_rank_variables): + raise ValueError( + "Destination subset has free variables that are not rank " + "variables") + + if FULLY_REPLICATED_RANK in src_free_vars or FULLY_REPLICATED_RANK in dst_free_vars: + raise RuntimeError( + "Fully replicated rank appeared in free variables, this should" + " not happen") + + inp_buffer, out_buffer = None, None + + if state.out_degree(self) != 1: + raise ValueError( + "SymbolicCommunication node must have exactly one output edge") + if state.in_degree(self) != 1: + raise ValueError( + "SymbolicCommunication node must have exactly one input edge") + out_buffer = sdfg.arrays[state.out_edges(self)[0].data.data] + inp_buffer = sdfg.arrays[state.in_edges(self)[0].data.data] + if inp_buffer.dtype != out_buffer.dtype: + raise ValueError( + "Input and output buffers must have the same data type") + + # Check that subset sizes are correct + if not utils.all_equal(self.src_subset.size_exact(), inp_buffer.shape): + raise ValueError( + f"Source subset size {self.src_subset.size_exact()} does not" + f" match input buffer size {inp_buffer.shape}") + if not utils.all_equal(self.dst_subset.size_exact(), out_buffer.shape): + raise ValueError( + f"Destination subset size {self.dst_subset.size_exact()} does" + f" not match output buffer size {out_buffer.shape}") + + # Check process grids + if self.src_pgrid and self.src_pgrid not in sdfg.process_grids: + raise ValueError("Source process grid not found") + if self.dst_pgrid and self.dst_pgrid not in sdfg.process_grids: + raise ValueError("Destination process grid not found") + + return inp_buffer, out_buffer + + def __str__(self): + return (f"{self.src_global_array}[{self.src_subset}] ->" + f" {self.dst_global_array}[{self.dst_subset}]") diff --git a/daceml/distributed/communication/subarrays.py 
b/daceml/distributed/communication/subarrays.py new file mode 100644 index 00000000..1948bcde --- /dev/null +++ b/daceml/distributed/communication/subarrays.py @@ -0,0 +1,368 @@ +import typing +from typing import Optional, List, Tuple + +import sympy as sp + +import dace.library +from dace import SDFG, SDFGState, subsets, symbolic, data +from dace.transformation import transformation as pm +from dace.libraries import mpi + +from .. import utils as distr_utils + +if typing.TYPE_CHECKING: + from .node import DistributedMemlet + + +class CommunicationSolverException(Exception): + """ + Exception raised when the communication solver fails to find a solution. + """ + pass + + +MatchedDimensions = List[Tuple[Optional[symbolic.symbol], symbolic.symbol]] + + +def match_subset_axis_to_pgrid( + subset: subsets.Range, + grid_variables: List[symbolic.symbol]) -> MatchedDimensions: + """ + Matches each axis of the subset to a grid variable if that axis is tiled + using that grid variable. + + For example, an axis i*32:(i+1)*32 is matched to the grid variable i with + block size 32, and would add (i, 32) to the result. + + If a dimension is not tiled by any variable, the result will contain a None + for that dimension. + + :param subset: The subset to match. + :param grid_variables: The grid variables to match against. + :raises CommunicationSolverException: If matching fails + :return: A list of tuples, each containing the grid variable and the + matched block size. 
+ """ + + # avoid import loop + from .node import FULLY_REPLICATED_RANK + + blocks = [] + sizes = subset.size_exact() + for i, (start, end, step) in enumerate(subset): + if step != 1: + raise CommunicationSolverException("Subset step must be 1") + bs = sp.Wild("bs", exclude=grid_variables) + + expr = sp.Basic(*map(symbolic.pystr_to_symbolic, (start, end))) + if not expr.free_symbols: + # no free symbols; this is a constant range and needs to be + # replicated to every rank + blocks.append((None, sizes[i])) + continue + + for p in grid_variables: + if p.name == FULLY_REPLICATED_RANK: + # this is not a valid variable to match + continue + + # try to match with this block variable + pattern = sp.Basic(p * bs, (p + 1) * bs - 1) + + matches = expr.match(pattern) + if matches is not None: + break + else: + # couldn't detect a block: exit + raise CommunicationSolverException( + "Could not match subset axis {} to grid variables {}".format( + subset[i], ', '.join(map(str, grid_variables)))) + + blocks.append((p, matches[bs])) + + return blocks + + +def compute_scatter_color(parent_grid_variables: List[symbolic.symbol], + parent_grid_shape: List[int], + matched_dimensions: MatchedDimensions) -> List[bool]: + # avoid import loop + from .node import FULLY_REPLICATED_RANK + + # We need to setup a broadcast grid to make sure the ranks on the + # remaining dimensions of the scatter grid get their values. We will + # split the process grid into two subgrids to achieve this + + dim_to_idx = { + v: i + for i, v in enumerate(parent_grid_variables) + if v.name != FULLY_REPLICATED_RANK + } + + # these are the dimensions we need to scatter our data over (i.e. 
not replicated) + scattered_dims = { + dim_to_idx[s] + for s, _ in matched_dimensions if s is not None + } + + # this is the number of unpartitioned dimensions in our subset + # these need to be mapped to a dimension of the process grid with size one + required_full_replication_dims = len(matched_dimensions) - len( + scattered_dims) + + # empty dims are the indices of dims of size 1 in the global pgrid + empty_dims = { + i + for i, s in enumerate(parent_grid_variables) + if s.name == FULLY_REPLICATED_RANK + } + assert all(parent_grid_shape[i] == 1 for i in empty_dims) + + # the indices of size 1 dimensions we choose for the scatter grid + empty_scatter_dims = set() + for i in range(required_full_replication_dims): + # choose any empty rank + parent_empty_dim = next(iter(empty_dims)) + empty_dims.remove(parent_empty_dim) + scattered_dims.add(parent_empty_dim) + empty_scatter_dims.add(parent_empty_dim) + + # scatter_color[i] == True iff the i'th dimension is kept in the scatter grid + scatter_color = [ + i in scattered_dims for i in range(len(parent_grid_shape)) + ] + return scatter_color + + +def try_construct_subarray( + sdfg: SDFG, state: SDFGState, pgrid_name: str, global_desc: data.Data, + subset: subsets.Range, grid_variables: List[symbolic.symbol], + scatter: bool, + dry_run: bool) -> Optional[Tuple[str, str, Optional[str]]]: + """ + Try to convert the given end of the distributed memlet to a subarray, + creating the necessary process grids. Returns the name of the subarray, the + name of the scatter grid and the name of the bcast/reduction grid. + + :param sdfg: The SDFG to operate on. + :param state: The state to operate on (Dummy tasklets will be inserted + here) + :param pgrid_name: The name of the pgrid to use. + :param global_desc: The global data descriptor of the array to communicate. + :param subset: The end of the distributed memlet to convert. 
+ :param grid_variables: The process grid corresponding to the computation to + which this is either an input or an output. + :param scatter: True if this is for a scatter. + :param dry_run: If True, don't actually create the grids and subarray. + Instead, return None. + :return: The name of the subarray, the name of the scatter grid and the + name of the bcast grid (possibly ``None`` if no broadcast is + necessary). + """ + + pgrid = sdfg.process_grids[pgrid_name] + + matched_dimensions = match_subset_axis_to_pgrid(subset, grid_variables) + + if len(matched_dimensions) > len(pgrid.shape): + # haven't thought about this yet + raise NotImplementedError() + elif len(matched_dimensions) < len(pgrid.shape): + + scatter_color = compute_scatter_color(grid_variables, pgrid.shape, + matched_dimensions) + assert sum(scatter_color) == subset.dims() + # The broadcast grid provides the data to the ranks that are not + # covered in the scatter grid. + # It's always the inverse of the scatter grid. + bcast_color = list(map(lambda x: not x, scatter_color)) + + subgrid_shape = lambda color: [ + s for s, c in zip(pgrid.shape, color) if c + ] + scatter_shape = subgrid_shape(scatter_color) + bcast_shape = subgrid_shape(bcast_color) + + if not dry_run: + scatter_grid_name = sdfg.add_pgrid( + shape=scatter_shape, + parent_grid=pgrid_name, + color=scatter_color, + exact_grid=None if scatter else 0) + + bcast_grid_name = sdfg.add_pgrid(shape=bcast_shape, + parent_grid=pgrid_name, + color=bcast_color) + for name, shape in ((scatter_grid_name, scatter_shape), + (bcast_grid_name, bcast_shape)): + distr_utils.initialize_fields(state, [ + f'MPI_Comm {name}_comm;', + f'MPI_Group {name}_group;', + f'int {name}_coords[{len(shape)}];', + f'int {name}_dims[{len(shape)}];', + f'int {name}_rank;', + f'int {name}_size;', + f'bool {name}_valid;', + ]) + current_idx = 0 + symbol_to_idx = {} + for i, s in enumerate(grid_variables): + if scatter_color[i]: + symbol_to_idx[s] = current_idx + current_idx += 
1 + else: + scatter_shape = pgrid.shape + scatter_grid_name = pgrid_name + bcast_grid_name = None + symbol_to_idx = {s: i for i, s in enumerate(grid_variables)} + + # The indices of dims of size 1 in the scatter_grid + empty_dims = {i for i, s in enumerate(scatter_shape) if s == 1} + + # Now we need to map dimensions of the subset to the dimensions in the + # scatter grid + correspondence: List[int] = [] + + for i, (p, bs) in enumerate(matched_dimensions): + if p is None: + # we can choose to map to any of the empty ranks + chosen = next(iter(empty_dims)) + empty_dims.remove(chosen) + correspondence.append(chosen) + else: + grid_index = symbol_to_idx[p] + + global_block_size = global_desc.shape[i] / pgrid.shape[grid_index] + if global_block_size % bs != 0: + return None + + if bs != global_block_size: + raise CommunicationSolverException( + f"Detected block size {bs} (on axis {subset[i]}) does not" + f" match global block size {global_block_size}") + correspondence.append(grid_index) + + assert all(map(lambda i: 0 <= i < len(scatter_shape), correspondence)) + + if not dry_run: + # create subarray + subarray_name = sdfg.add_subarray(dtype=global_desc.dtype, + shape=global_desc.shape, + subshape=subset.size_exact(), + pgrid=scatter_grid_name, + correspondence=correspondence) + distr_utils.initialize_fields(state, [ + f'MPI_Datatype {subarray_name};', f'int* {subarray_name}_counts;', + f'int* {subarray_name}_displs;' + ]) + + return subarray_name, scatter_grid_name, bcast_grid_name + else: + return None + + +class CommunicateSubArrays(pm.ExpandTransformation): + + environments = [] + + def can_be_applied(self, state: SDFGState, *_, **__): + node: 'DistributedMemlet' = state.node( + self.subgraph[type(self)._match_node]) + sdfg = state.parent + + src_vars = list( + map(symbolic.pystr_to_symbolic, node.src_rank_variables)) + dst_vars = list( + map(symbolic.pystr_to_symbolic, node.dst_rank_variables)) + + try: + if node.src_pgrid is not None: + garr = 
sdfg.arrays[node.src_global_array] + try_construct_subarray(sdfg, + state, + node.src_pgrid, + garr, + node.src_subset, + src_vars, + False, + dry_run=True) + + if node.dst_pgrid is not None: + garr = sdfg.arrays[node.dst_global_array] + try_construct_subarray(sdfg, + state, + node.dst_pgrid, + garr, + node.dst_subset, + dst_vars, + False, + dry_run=True) + except CommunicationSolverException: + return False + + return True + + @staticmethod + def expansion(node: 'DistributedMemlet', state: SDFGState, sdfg: SDFG): + + src_desc, dst_desc = node.validate(sdfg, state) + + src_vars = list( + map(symbolic.pystr_to_symbolic, node.src_rank_variables)) + dst_vars = list( + map(symbolic.pystr_to_symbolic, node.dst_rank_variables)) + + # There are 3 cases: + + if node.src_pgrid is not None and node.dst_pgrid is not None: + # 1. src and dst both have pgrids + raise NotImplementedError() + + elif node.src_pgrid is not None or node.dst_pgrid is not None: + # 2. only one of the two has a pgrid + # in this case we emit a BlockScatter or BlockGather + scatter = node.dst_pgrid is not None + + if scatter: + pgrid_name = node.dst_pgrid + garr = sdfg.arrays[node.dst_global_array] + subset = node.dst_subset + rvars = dst_vars + else: + pgrid_name = node.src_pgrid + garr = sdfg.arrays[node.src_global_array] + subset = node.src_subset + rvars = src_vars + + subarray_name, scatter_grid, bcast_grid = try_construct_subarray( + sdfg, + state, + pgrid_name, + garr, + subset, + rvars, + scatter, + dry_run=False) + + if scatter: + expansion = mpi.BlockScatter(node.label, + subarray_type=subarray_name, + scatter_grid=scatter_grid, + bcast_grid=bcast_grid) + else: + expansion = mpi.BlockGather(node.label, + subarray_type=subarray_name, + gather_grid=scatter_grid, + reduce_grid=bcast_grid) + + # clean up connectors to match the new node + expansion.add_in_connector("_inp_buffer") + expansion.add_out_connector("_out_buffer") + state.in_edges(node)[0].dst_conn = "_inp_buffer" + 
state.out_edges(node)[0].src_conn = "_out_buffer" + + return expansion + else: + # both have no pgrid + # this should just be a copy (?) + raise NotImplementedError() diff --git a/daceml/distributed/schedule.py b/daceml/distributed/schedule.py new file mode 100644 index 00000000..e39b00c4 --- /dev/null +++ b/daceml/distributed/schedule.py @@ -0,0 +1,603 @@ +""" +A distributed schedule for a subgraph is a mapping from each +map entry node to process grid dimensions. +""" +import copy +import collections +from typing import List, Tuple, Dict, Set, Optional, Union +import itertools + +import networkx as nx + +import sympy as sp + +from dace import nodes, SDFG, SDFGState, symbolic, subsets, memlet, data +from dace.sdfg import propagation, utils as sdfg_utils +from dace.transformation import helpers as xfh + +from daceml.autodiff import analysis +from daceml.util import utils + +NumBlocks = List[int] +DistributedSchedule = Dict[nodes.Map, NumBlocks] +MapNodes = Tuple[nodes.MapEntry, nodes.MapExit] + +from .communication import node +from . 
import utils as distr_utils + + +def find_map_nodes(state: SDFGState, map_node: nodes.Map) -> MapNodes: + found = [ + node for node in state.nodes() + if isinstance(node, (nodes.MapEntry, + nodes.MapExit)) and node.map == map_node + ] + + assert len( + found + ) == 2, "Found {} map scope nodes for map {}, expected exactly 2".format( + len(found), map_node) + + x, y = found + if isinstance(x, nodes.MapEntry): + assert isinstance( + y, nodes.MapExit), f"Found two entry nodes for map {map_node}" + return x, y + else: + assert isinstance( + y, nodes.MapEntry), f"Found two exit nodes for map {map_node}" + return y, x + + +def compute_tiled_map_range(nmap: nodes.Map, + num_blocks: NumBlocks) -> subsets.Range: + """ + Compute the range of the map after rank-tiling it with num_blocks + """ + new_ranges = [] + exact_sizes = nmap.range.size_exact() + for (td_from, td_to, td_step), block_size, exact_size in utils.strict_zip( + nmap.range, num_blocks, exact_sizes): + if td_step != 1: + raise NotImplementedError("Cannot tile map with step") + + # check that we divide evenly + if sp.Mod(exact_size, block_size).simplify() != 0: + raise ValueError(f"{exact_size} is not divisible by {block_size}") + td_to_new = exact_size // block_size - 1 + td_step_new = td_step + new_ranges.append((0, td_to_new, td_step_new)) + return subsets.Range(new_ranges) + + +LocalSubsets = Dict[str, subsets.Range] +RankVariables = List[Optional[symbolic.symbol]] + + +def propagate_rank_local_subsets( + sdfg: SDFG, state: SDFGState, map_nodes: MapNodes, num_blocks: NumBlocks +) -> Tuple[RankVariables, Tuple[LocalSubsets, LocalSubsets]]: + """ + Compute the subset rank local subsets we need. + + For each rank-tiled parameter, the returned dictionary contains a mapping + from the original parameter name to the variable name of the rank index. + + The two-tuple contains mappings from array names to the subset expressions + for the inputs and outputs. 
+ """ + me, mx = map_nodes + map_node = me.map + + # We need to reindex using "fake" block indices + # Create new symbols for the block indices and block sizes + used_vars: Set[str] = set(map_node.params) + rank_variables: RankVariables = [] + # build a new range for the fake local map + outer_range = [] + + for param, size, (start, end, + step), n_blocks in zip(map_node.params, + map_node.range.size_exact(), + map_node.range, num_blocks): + if n_blocks != 1: + rank_variable = utils.find_str_not_in_set(used_vars, + f"__block{param}") + used_vars.add(rank_variable) + outer_range.append(( + symbolic.pystr_to_symbolic( + f"{rank_variable} * ({size} / {n_blocks})"), + symbolic.pystr_to_symbolic( + f"({rank_variable} + 1) * ({size} / {n_blocks}) - 1"), + symbolic.pystr_to_symbolic(f"{step}"), + )) + + rank_variables.append(symbolic.pystr_to_symbolic(rank_variable)) + else: + rank_variables.append(symbolic.symbol(node.FULLY_REPLICATED_RANK)) + outer_range.append((start, end, step)) + + outer_range = subsets.Range(outer_range) + + # Collect all defined variables + scope_node_symbols = set(conn for conn in me.in_connectors + if not conn.startswith('IN_')) + defined_vars = [ + symbolic.pystr_to_symbolic(s) + for s in (state.symbols_defined_at(me).keys() + | sdfg.constants.keys()) if s not in scope_node_symbols + ] + defined_vars.extend(map(symbolic.pystr_to_symbolic, used_vars)) + defined_vars = set(defined_vars) + + results: Tuple[LocalSubsets, LocalSubsets] = ({}, {}) + for is_input, result in zip([True, False], results): + + # gather internal memlets by the out array they write to + internal_memlets: Dict[ + str, List[memlet.Memlet]] = collections.defaultdict(list) + edges = state.out_edges(me) if is_input else state.in_edges(mx) + for edge in edges: + if edge.data.is_empty(): + continue + internal_memlets[edge.data.data].append(edge.data) + + for arr_name, memlets in internal_memlets.items(): + # compute the rank local subset using propagation through our + # "fake" MPI 
map + rank_local = propagation.propagate_subset( + memlets, sdfg.arrays[arr_name], map_node.params, outer_range, + defined_vars, not is_input) + assert isinstance(rank_local.subset, subsets.Range) + result[arr_name] = rank_local.subset + return rank_variables, results + + +GlobalToLocal = List[Tuple[nodes.AccessNode, nodes.AccessNode, subsets.Range]] + + +def rank_tile_map( + sdfg: SDFG, + state: SDFGState, + map_nodes: MapNodes, + num_blocks: NumBlocks, + force_local_names: Optional[Dict[str, str]] = None, +) -> Tuple[RankVariables, GlobalToLocal, GlobalToLocal]: + """ + Tile the map according to the given block sizes, create rank-local views + for the reads and writes to global arrays. + + * Detect and handle local write-conflicts + * Handle views correctly + * Insert rank-local arrays and reroute the map edges to go to the local + views + + :param sdfg: The SDFG to operate on. + :param state: The state to operate on. + :param map_nodes: The map nodes to operate on. + :param num_blocks: The number of blocks to tile the map with in each dimension. + :param force_local_names: A dictionary mapping global array names to local array names. + :returns: created symbolic variables for each rank axis, and two lists of + tuples associating global AccessNodes, local AccessNodes and the + symbolic communication constraints. The first is for reads, the + second for writes. 
+ """ + me, mx = map_nodes + + # Compute tiled map range + new_range = compute_tiled_map_range(me.map, num_blocks) + + rank_variables, (input_subsets, + output_subsets) = propagate_rank_local_subsets( + sdfg, state, map_nodes, num_blocks) + + # set the new range + me.map.range = new_range + + to_iter = itertools.chain( + zip(input_subsets.items(), itertools.repeat(True)), + zip(output_subsets.items(), itertools.repeat(False))) + + result_read: GlobalToLocal = [] + result_write: GlobalToLocal = [] + for (arr_name, new_subset), is_input in to_iter: + + # Determine the outer edge + outer_edges = state.in_edges(me) if is_input else state.out_edges(mx) + outer_edge_cands = [ + edge for edge in outer_edges if edge.data.data == arr_name + ] + if len(outer_edge_cands) > 1: + # FIXME this could be supported using a preprocessing + # transformation + raise NotImplementedError( + "Multiple outer edges to one array not implemented") + elif len(outer_edge_cands) == 0: + raise ValueError(f"No outer edge to {arr_name}") + outer_edge = outer_edge_cands[0] + + global_name = outer_edge.data.data + global_node = outer_edge.src if is_input else outer_edge.dst + + if not isinstance(global_node, nodes.AccessNode): + # FIXME tasklets should be replicated for each rank + raise NotImplementedError("Cannot handle non-access nodes yet") + elif isinstance(global_node.desc(sdfg), data.View): + raise NotImplementedError("Cannot handle views yet") + + if force_local_names is not None and global_name in force_local_names: + # reuse descriptor + local_name = force_local_names[global_name] + else: + local_name, _ = sdfg.add_transient( + name="local_" + global_name, + shape=[ + symbolic.overapproximate(r).simplify() + for r in new_subset.size_exact() + ], + dtype=sdfg.arrays[global_name].dtype, + find_new_name=True) + + local_node = state.add_access(local_name) + if is_input: + result_read.append((global_node, local_node, new_subset)) + else: + result_write.append((global_node, local_node, 
new_subset)) + + if is_input: + redirect_args = dict(new_src=local_node) + else: + redirect_args = dict(new_dst=local_node) + + new_edge = xfh.redirect_edge(state, + outer_edge, + new_data=local_name, + **redirect_args) + new_edge.data.subset = new_subset + + return rank_variables, result_read, result_write + + +def ordered_nodes_by_state( + sdfg: SDFG +) -> Dict[SDFGState, List[Union[nodes.Map, nodes.NestedSDFG]]]: + ordered_maps = {} + + for state in sdfg.nodes(): + top_level_nodes = set(state.scope_children()[None]) + result_nodes = [ + node.map if isinstance(node, nodes.MapEntry) else node + for node in sdfg_utils.dfs_topological_sort(state) + if isinstance(node, (nodes.MapEntry, + nodes.NestedSDFG)) and node in top_level_nodes + ] + ordered_maps[state] = result_nodes + return ordered_maps + + +def rank_tile_nested( + sdfg: SDFG, + state: SDFGState, + nnode: nodes.NestedSDFG, + schedule: DistributedSchedule, +) -> Tuple[NumBlocks, RankVariables, GlobalToLocal, GlobalToLocal]: + """ + Rank tile a nested SDFG. To make this viable in the cases we want to support + (typically just initialization for reductions) there are the following constraints: + + * Only a single ProcessGrid is used for the whole SDFG. 
+ * No communication is added within the NSDFG (this enables fusability outside of it) + * As a result, schedules in the NSDFG must be 'consistent', in that if they + write to the same array, the communication constraints must be the same + """ + + nsdfg = nnode.sdfg + + # rule out unsqueezes + for edge in itertools.chain(state.in_edges(nnode), state.out_edges(nnode)): + if not utils.all_equal(edge.data.subset.size_exact(), + sdfg.arrays[edge.data.data].shape): + raise NotImplementedError( + "Cannot handle unsqueezes on NestedSDFG connectors") + + ordered_maps = ordered_nodes_by_state(nsdfg) + + rank_variables: Dict[str, int] = {} + num_fully_replicated: int = 0 + + constraints: Dict[str, subsets.Range] = {} + global_to_local: Dict[str, str] = {} + + # Rank tile each map, collecting the constraints by array + for nstate, map_nodes in ordered_maps.items(): + for map_node in map_nodes: + num_blocks = schedule[map_node] + + if len(num_blocks) != map_node.get_param_num(): + raise ValueError( + f"Schedule for {map_node} has {len(num_blocks)} " + f"block sizes, but {map_node.get_param_num()} are " + "required.") + + map_nodes = find_map_nodes(nstate, map_node) + + # modify the map to tile it + map_rank_variables, reads, writes = rank_tile_map( + nsdfg, + nstate, + map_nodes, + num_blocks, + force_local_names=global_to_local) + + # gather our current constraints, and delete the global access nodes + current_constraints = {} + for nglobal, nlocal, subset in reads: + current_constraints[nglobal.data] = subset + # update global mapping + global_to_local[nglobal.data] = nlocal.data + # delete global access node + nstate.remove_node(nglobal) + + for nglobal, nlocal, subset in writes: + # update global mapping + global_to_local[nglobal.data] = nlocal.data + if nglobal.data in current_constraints and current_constraints[ + nglobal.data] != subset: + raise ValueError( + "Inconsistent constraints found within map") + current_constraints[nglobal.data] = subset + 
nstate.remove_node(nglobal) + + # the tiled map now has its own rank variables. We now need to + # make these variables consistent with the outer variables since the + # whole nested SDFG only has one ProcessGrid + me, mx = map_nodes + + renaming = {} + + map_block_size_per_symbol = { + s.name: bs + for s, bs in zip(map_rank_variables, num_blocks) + } + # the arrays that already have constraints from another map + for name, subset in current_constraints.items(): + if name not in constraints: + constraints[name] = subset + continue + + previous_constraint = constraints[name] + + is_symbol = lambda x: x.is_symbol + # try to match the constraints + + used_symbols = set(subset.free_symbols) + used_symbols |= set(previous_constraint.free_symbols) + + def new_symbol_name(name): + name = utils.find_str_not_in_set(used_symbols, name) + used_symbols.add(name) + return name + + pattern = copy.deepcopy(subset) + wilds = { + s: sp.Wild(new_symbol_name(f"i{i}"), + properties=[ + is_symbol, lambda x: rank_variables[x.name] + == map_block_size_per_symbol[s] + ]) + for i, s in enumerate(subset.free_symbols) + } + pattern.replace(wilds) + + def range_to_basic_expr(self): + exprs = [] + for (rb, re, rs), ts in zip(self.ranges, self.tile_sizes): + exprs.append(rb) + exprs.append(re) + exprs.append(ts) + return sp.Basic(*exprs) + + match = range_to_basic_expr(previous_constraint).match( + range_to_basic_expr(pattern)) + if match is None: + raise ValueError( + "Inconsistent constraints found within map") + + renaming.update(match) + + map_num_replicated = 0 + # now check that the block sizes are consistent + # They should already be consistent due to the sympy matching above, + # but let's make sure anyway + for v, bs in zip(map_rank_variables, num_blocks): + if v in renaming: + v = renaming[v] + + if v.name == node.FULLY_REPLICATED_RANK: + map_num_replicated += 1 + continue + + if v.name in rank_variables: + if rank_variables[v.name] != bs: + raise ValueError( + "Inconsistent block sizes 
found within NestedSDFG") + else: + rank_variables[v.name] = bs + num_fully_replicated = max(num_fully_replicated, + map_num_replicated) + + # Concretize process grid ordering + ordered_rank_variables = [] + block_sizes = [] + for ov, bs in rank_variables.items(): + ordered_rank_variables.append(ov) + block_sizes.append(bs) + ordered_rank_variables += [node.FULLY_REPLICATED_RANK + ] * num_fully_replicated + block_sizes += [1] * num_fully_replicated + + # Swap out the global connectors for local view connectors on the outside of the NSDFG + nnode.in_connectors = { + global_to_local[k]: v + for k, v in nnode.in_connectors.items() + } + nnode.out_connectors = { + global_to_local[k]: v + for k, v in nnode.out_connectors.items() + } + + # create outer transients for the local views + to_iter = itertools.chain( + zip(state.in_edges(nnode), itertools.repeat(True)), + zip(state.out_edges(nnode), itertools.repeat(False))) + + reads = [] + writes = [] + for edge, is_read in to_iter: + global_name = edge.dst_conn if is_read else edge.src_conn + local_name_in_nsdfg = global_to_local[global_name] + inner_desc = nsdfg.arrays[local_name_in_nsdfg] + outer_local_name = sdfg.add_datadesc( + name=local_name_in_nsdfg, + datadesc=copy.deepcopy(inner_desc), + find_new_name=True) + access = state.add_access(outer_local_name) + + if is_read: + reads.append((edge.src, access, constraints[global_name])) + redirect_args = dict(new_dst_conn=local_name_in_nsdfg, + new_src=access) + else: + writes.append((edge.dst, access, constraints[global_name])) + redirect_args = dict(new_src_conn=local_name_in_nsdfg, + new_dst=access) + + new_edge = xfh.redirect_edge(state, + edge, + new_data=outer_local_name, + **redirect_args) + new_edge.data.subset = constraints[global_name] + + for global_name, local_name in global_to_local.items(): + if not nsdfg.arrays[global_name].transient: + nsdfg.arrays[local_name].transient = False + del nsdfg.arrays[global_name] + + # specialize process grid vars inside the 
nsdfg + sdfg.specialize({v: 0 for v in ordered_rank_variables}) + ordered_rank_variables = list( + map(symbolic.pystr_to_symbolic, ordered_rank_variables)) + return block_sizes, ordered_rank_variables, reads, writes + + +def lower(sdfg: SDFG, schedule: DistributedSchedule): + """ + Attempt to lower the SDFG to a SPMD MPI SDFG to distribute computation. + + The schedule defines the size of the process grids used to compute each of the parallel maps. + + :param sdfg: The SDFG to lower. + :param schedule: The schedule to use. + :note: Operates in-place. + """ + missing = set(distr_utils.all_top_level_maps(sdfg)).difference( + schedule.keys()) + + if missing: + raise ValueError( + f"Missing schedule for maps: {', '.join(map(lambda x: x.label, missing))}" + ) + + # Order the schedule topologically for each state + ordered_nodes = ordered_nodes_by_state(sdfg) + + # each map has a main process grid + # with the dimension given by the schedule + process_grids: Dict[str, RankVariables] = {} + + for state, top_level_nodes in ordered_nodes.items(): + for top_level_node in top_level_nodes: + if isinstance(top_level_node, nodes.NestedSDFG): + num_blocks, rank_variables, reads, writes = rank_tile_nested( + sdfg, state, top_level_node, schedule) + else: + map_node: nodes.Map = top_level_node + num_blocks = schedule[map_node] + + if len(num_blocks) != map_node.get_param_num(): + raise ValueError( + f"Schedule for {map_node} has {len(num_blocks)} " + f"block sizes, but {map_node.get_param_num()} are " + "required.") + + map_nodes = find_map_nodes(state, map_node) + + # modify the map to tile it + rank_variables, reads, writes = rank_tile_map( + sdfg, state, map_nodes, num_blocks) + rank_variable_names: List[str] = list( + map(lambda s: s.name, rank_variables)) + + # Create a process grid that will be used for communication + process_grid_name = sdfg.add_pgrid(shape=num_blocks) + distr_utils.initialize_fields(state, [ + f'MPI_Comm {process_grid_name}_comm;', + f'MPI_Group 
{process_grid_name}_group;', + f'int {process_grid_name}_coords[{len(num_blocks)}];', + f'int {process_grid_name}_dims[{len(num_blocks)}];', + f'int {process_grid_name}_rank;', + f'int {process_grid_name}_size;', + f'bool {process_grid_name}_valid;', + ]) + + process_grids[process_grid_name] = rank_variables + + to_iter = itertools.chain(zip(reads, itertools.repeat(True)), + zip(writes, itertools.repeat(False))) + + for (nglobal, nlocal, subset), is_read in to_iter: + if not is_read: + # FIXME We don't need to communicate this if it is a + # derived schedule and one of our siblings comes after us. + pass + + full_subset = subsets.Range.from_array(nglobal.desc(sdfg)) + + global_params = dict(rank_variables=[], + pgrid=None, + subset=full_subset, + global_array=nglobal.data) + + local_params = dict(rank_variables=rank_variable_names, + pgrid=process_grid_name, + subset=subset, + global_array=nglobal.data) + + global_prefix = "src_" if is_read else "dst_" + local_prefix = "dst_" if is_read else "src_" + + add_prefix = lambda d, p: {p + k: v for k, v in d.items()} + + comm = node.DistributedMemlet( + name="communicate_" + nglobal.data, + **add_prefix(global_params, global_prefix), + **add_prefix(local_params, local_prefix)) + + state.add_node(comm) + src = nglobal if is_read else nlocal + dst = nlocal if is_read else nglobal + state.add_edge(src, None, comm, None, + sdfg.make_array_memlet(src.data)) + state.add_edge(comm, None, dst, None, + sdfg.make_array_memlet(dst.data)) + + sdfg.validate() + utils.expand_nodes( + sdfg, predicate=lambda n: isinstance(n, node.DistributedMemlet)) + sdfg.validate() + + # Now that we are done lowering, we can instantiate the process grid + # variables with zero, since each rank only sees its section of the array + for _, variables in process_grids.items(): + repl_dict = {v.name: 0 for v in variables} + sdfg.specialize(repl_dict) diff --git a/daceml/distributed/utils.py b/daceml/distributed/utils.py new file mode 100644 index 
00000000..e92560e0 --- /dev/null +++ b/daceml/distributed/utils.py @@ -0,0 +1,171 @@ +from typing import List, Iterator, Tuple, Union, Dict + +import numpy as np +import pytest + +import dace +from dace import SDFG, SDFGState, nodes, dtypes +from dace.sdfg import utils as sdfg_utils +from dace.libraries import mpi +from daceml.util import utils + + +def initialize_fields(state: SDFGState, fields: List[str]): + """ + Add a dummy library node to initialize the given fields to the SDFG + """ + sdfg = state.parent + dummy = mpi.Dummy("initialize_fields", fields) + + state.add_node(dummy) + + # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. + dummy_name, scal = sdfg.add_scalar("dummy", + dace.int32, + transient=True, + find_new_name=True) + wnode = state.add_write(dummy_name) + state.add_edge(dummy, '__out', wnode, None, + dace.Memlet.from_array(dummy_name, scal)) + + +def all_top_level_maps( + sdfg: SDFG, + yield_parent=False +) -> Iterator[Union[nodes.Map, Tuple[nodes.Map, SDFGState]]]: + for state in sdfg.nodes(): + for node in state.scope_children()[None]: + if isinstance(node, nodes.MapEntry): + if yield_parent: + yield node.map, state + else: + yield node.map + elif isinstance(node, nodes.NestedSDFG): + # also check top-level nested SDFGs + yield from all_top_level_maps(node.sdfg, yield_parent) + + +def add_debug_rank_cords_tasklet(sdfg: SDFG): + """ + Add a tasklet to the start of the SDFG that will print out the process coordinates in each process grid. 
+ """ + new_state = sdfg.add_state_before(sdfg.start_state, 'debug_MPI') + + code = """{ + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + """ + + for grid_name, desc in sdfg.process_grids.items(): + code += """ + if (__state->{grid_name}_rank == 0) {{ + printf("{grid_name} dimensions: "); + for (int i = 0; i < {grid_dims}; i++) {{ + printf("%d ", __state->{grid_name}_dims[i]); + }} + printf("\\n"); + }} + + printf("Hello from global rank %d, rank %d in grid {grid_name}, with coords: ", rank, __state->{grid_name}_rank); + + for (int i = 0; i < {grid_dims}; i++) {{ + printf("%d ", __state->{grid_name}_coords[i]); + }} + printf("\\n"); + """.format(grid_name=grid_name, grid_dims=len(desc.shape)) + code += "}" + + # add a tasklet that writes to nothing + tasklet = nodes.Tasklet("debug_MPI", {}, {"__out"}, code, + dtypes.Language.CPP) + + new_state.add_node(tasklet) + + # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. + dummy_name, scal = sdfg.add_scalar("dummy", + dace.int32, + transient=True, + find_new_name=True) + wnode = new_state.add_write(dummy_name) + new_state.add_edge(tasklet, '__out', wnode, None, + dace.Memlet.from_array(dummy_name, scal)) + + +def add_debugprint_tasklet(sdfg: SDFG, state: SDFGState, + node: nodes.AccessNode): + """ + Insert a tasklet that just prints out the given data + """ + desc = node.desc(sdfg) + loops = "\n".join("for (int i{v} = 0; i{v} < {s}; i{v}++)".format(v=i, s=s) + for i, s in enumerate(desc.shape)) + code = """ + {loops} + {{ + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + printf("RANK %d: {name}[%d] = %d\\n", rank, {indices}, {name}[{indices}]); + }} + """.format(name=node.data, + loops=loops, + indices=",".join( + ["i{}".format(i) for i in range(len(desc.shape))])) + + # add a tasklet that writes to nothing + tasklet = nodes.Tasklet("print_" + node.data, {"__in"}, {"__out"}, code, + dtypes.Language.CPP) + + state.add_node(tasklet) + state.add_edge(node, None, tasklet, '__in', 
+ sdfg.make_array_memlet(node.data)) + + # Pseudo-writing to a dummy variable to avoid removal of Dummy node by transformations. + dummy_name, scal = sdfg.add_scalar("dummy", + dace.int32, + transient=True, + find_new_name=True) + state.add_edge(tasklet, '__out', state.add_write(dummy_name), None, + dace.Memlet.from_array(dummy_name, scal)) + + +def arange_with_size(size): + return np.arange(utils.prod(size), dtype=np.int64).reshape(size).copy() + + +def find_map_containing(sdfg, name) -> nodes.Map: + cands = [] + for node, state in all_top_level_maps(sdfg, yield_parent=True): + if name in node.label: + cands.append(node) + if len(cands) == 1: + return cands[0] + else: + raise ValueError("Found {} candidates for map name {}".format( + len(cands), name)) + + +def compile_and_call(sdfg, inputs: Dict[str, np.ndarray], + expected_output: np.ndarray, num_required_ranks: int): + MPI = pytest.importorskip("mpi4py.MPI") + commworld = MPI.COMM_WORLD.Dup() + commworld.Barrier() + rank = commworld.Get_rank() + size = commworld.Get_size() + + if size < num_required_ranks: + raise ValueError( + "This test requires at least {} ranks".format(num_required_ranks)) + + func = sdfg_utils.distributed_compile(sdfg, commworld) + + if rank == 0: + result = func(**inputs) + np.testing.assert_allclose(result, expected_output) + else: + dummy_inputs = { + k: np.zeros_like(v, shape=(1, )) + for k, v in inputs.items() + } + func(**dummy_inputs) + commworld.Barrier() + commworld.Free() diff --git a/daceml/util/utils.py b/daceml/util/utils.py index 0b3d503f..c8139a9c 100644 --- a/daceml/util/utils.py +++ b/daceml/util/utils.py @@ -3,6 +3,7 @@ import logging from typing import Optional, Set, Callable +import itertools from functools import wraps import dace @@ -332,3 +333,37 @@ def all_equal(a, b) -> bool: if len(a) != len(b): return False return all(x == y for x, y in zip(a, b)) + + +# From https://stackoverflow.com/a/40596355 +def strict_zip(*iterables): + """ + Zip iterables, requiring that 
they are equal in length. + """ + class ExhaustedError(Exception): + def __init__(self, index): + self.index = index + + def raising_iter(i): + raise ExhaustedError(i) + yield + + def terminate_iter(i, iterable): + return itertools.chain(iterable, raising_iter(i)) + + iterators = [terminate_iter(*args) for args in enumerate(iterables)] + try: + yield from zip(*iterators) + except ExhaustedError as exc: + index = exc.index + if index > 0: + raise RuntimeError( + 'iterable {} exhausted first'.format(index)) from None + # Check that all other iterators are also exhausted. + for i, iterator in enumerate(iterators[1:], start=1): + try: + next(iterator) + except ExhaustedError: + pass + else: + raise RuntimeError('iterable {} is longer'.format(i)) from None diff --git a/setup.py b/setup.py index 94df8ad0..744328a0 100644 --- a/setup.py +++ b/setup.py @@ -56,4 +56,5 @@ 'sphinx', 'sphinx_rtd_theme', 'sphinx-autodoc-typehints', 'sphinx-gallery', 'matplotlib', 'jinja2<3.1' ], + 'distributed': ['mpi4py'], }) diff --git a/tests/distributed/mpi_mute.py b/tests/distributed/mpi_mute.py new file mode 100644 index 00000000..5d600510 --- /dev/null +++ b/tests/distributed/mpi_mute.py @@ -0,0 +1,24 @@ +""" +Pytest plugin to mute all ranks that are not 0 + +Requires mpi4py +""" + +import pytest + +from mpi4py import MPI + + +def pytest_addoption(parser): + parser.addoption("--unmute-all-ranks", + action="store_true", + help="Unmute all MPI ranks") + + +@pytest.mark.trylast +def pytest_configure(config): + unmute = config.getoption("--unmute-all-ranks") + if MPI.COMM_WORLD.Get_rank() != 0 and not unmute: + # unregister the standard reporter + standard_reporter = config.pluginmanager.getplugin('terminalreporter') + config.pluginmanager.unregister(standard_reporter) diff --git a/tests/distributed/test_elementwise.py b/tests/distributed/test_elementwise.py new file mode 100644 index 00000000..f9985b8e --- /dev/null +++ b/tests/distributed/test_elementwise.py @@ -0,0 +1,59 @@ +""" +These 
tests expect to be run with at least 8 ranks +""" +import pytest + +import dace + +from daceml.util import utils +from daceml.distributed import schedule +from daceml.distributed.utils import find_map_containing, compile_and_call, arange_with_size + + +@pytest.mark.parametrize("sizes", [ + [2], + [4], +]) +def test_elementwise_1d(sizes): + @dace + def program(x: dace.int64[64]): + return x + 5 + + sdfg = program.to_sdfg() + + map_entry = find_map_containing(sdfg, "") + schedule.lower(sdfg, {map_entry: sizes}) + + X = arange_with_size([64]) + expected = X + 5 + compile_and_call(sdfg, {'x': X.copy()}, expected, utils.prod(sizes)) + + +@pytest.mark.parametrize( + "sizes", + [ + [2, 1, 1], # fully replicate broadcasted => use 2d broadcast grid + [2, 2, 1], # no broadcast grid, 1d scatter grid + [2, 2, 2], # no broadcast grid, 2d scatter grid + [1, 2, 1], + ]) +def test_bcast_simple(sizes): + @dace + def program(x: dace.int64[4, 8, 16], y: dace.int64[8, 16]): + return x + y + + sdfg = program.to_sdfg() + sdfg.expand_library_nodes() + sdfg.simplify() + + map_entry = find_map_containing(sdfg, "") + schedule.lower(sdfg, {map_entry: sizes}) + + X = arange_with_size([4, 8, 16]) + Y = arange_with_size([8, 16]) + expected = X + Y + + compile_and_call(sdfg, { + 'x': X.copy(), + 'y': Y.copy() + }, expected, utils.prod(sizes)) diff --git a/tests/distributed/test_nested.py b/tests/distributed/test_nested.py new file mode 100644 index 00000000..cb4c4bbc --- /dev/null +++ b/tests/distributed/test_nested.py @@ -0,0 +1,73 @@ +import pytest +import numpy as np + +import dace +from dace import nodes +from dace.transformation import interstate +from dace.sdfg import utils as sdfg_utils + +from daceml.util import utils +from daceml.distributed import schedule +from daceml.distributed.utils import find_map_containing, compile_and_call, arange_with_size + + +@pytest.mark.parametrize( + "sizes", + [ + [1, 1], + [2, 1], + [2, 2], # parallelize along the reduction axis with MPI reduce + [2, 4], # 
parallelize along the reduction axis with MPI reduce + ]) +def test_reduce_simple(sizes): + @dace + def program(x: dace.int64[16, 16]): + return np.add.reduce(x, axis=1) + + sdfg = program.to_sdfg() + sdfg.expand_library_nodes() + # don't expand the NSDFG, this test tests that + assert any( + isinstance(n, nodes.NestedSDFG) for state in sdfg.nodes() + for n in state.nodes()) + + reduce = find_map_containing(sdfg, "reduce_output") + init = find_map_containing(sdfg, "init") + + schedule.lower(sdfg, {reduce: sizes, init: sizes[:1]}) + + X = arange_with_size([16, 16]) + expected = X.copy().sum(axis=1) + + compile_and_call(sdfg, {'x': X.copy()}, expected, utils.prod(sizes)) + + +def test_nested_two_maps(): + @dace + def nested(x): + y = np.zeros_like(x, shape=(16, 16)) + for i, j, k in dace.map[0:16, 0:16, 0:32]: + with dace.tasklet: + inp << x[i, k, j] + out = inp + out >> y(1, lambda x, y: x + y)[i, j] + return y + + @dace + def program(x: dace.int64[16, 32, 16]): + return nested(x) + + sdfg = program.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated(interstate.StateFusion) + sdfg.sdfg_list[1].apply_transformations_repeated(interstate.InlineSDFG) + sdfg.expand_library_nodes() + # don't expand the NSDFG, this test tests that + assert any( + isinstance(n, nodes.NestedSDFG) for state in sdfg.nodes() + for n in state.nodes()) + + elementwise = find_map_containing(sdfg, "test_nested_nested") + init = find_map_containing(sdfg, "full__map") + + schedule.lower(sdfg, {elementwise: [2, 2, 2], init: [2, 2]}) + sdfg.validate()