From fcd57ec09c2e0c518ef20779a35a2b9a14693b7d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 14 Mar 2020 18:27:10 +0100 Subject: [PATCH 01/32] Fix pytest-confusing syntax --- tests/library/blas_dot_test.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/library/blas_dot_test.py b/tests/library/blas_dot_test.py index 3550f43078..9988c1cc8a 100644 --- a/tests/library/blas_dot_test.py +++ b/tests/library/blas_dot_test.py @@ -94,7 +94,7 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default): ############################################################################### -def test_dot(implementation, dtype, sdfg): +def _test_dot(implementation, dtype, sdfg): try: dot = sdfg.compile() except (CompilerConfigurationError, CompilationError): @@ -126,17 +126,19 @@ def test_dot(implementation, dtype, sdfg): print("Test ran successfully for {}.".format(implementation)) -############################################################################### +def test_dot(): + _test_dot("32-bit pure SDFG", np.float32, make_sdfg("pure", dace.float32)) + _test_dot("64-bit pure SDFG", np.float64, make_sdfg("pure", dace.float64)) + _test_dot("32-bit MKL", np.float32, make_sdfg("MKL", dace.float32)) + _test_dot("64-bit MKL", np.float64, make_sdfg("MKL", dace.float64)) + _test_dot("32-bit cuBLAS", np.float32, + make_sdfg("cuBLAS", dace.float32, dace.StorageType.GPU_Global)) + _test_dot("64-bit cuBLAS", np.float64, + make_sdfg("cuBLAS", dace.float64, dace.StorageType.GPU_Global)) -if __name__ == "__main__": - test_dot("32-bit pure SDFG", np.float32, make_sdfg("pure", dace.float32)) - test_dot("64-bit pure SDFG", np.float64, make_sdfg("pure", dace.float64)) - test_dot("32-bit MKL", np.float32, make_sdfg("MKL", dace.float32)) - test_dot("64-bit MKL", np.float64, make_sdfg("MKL", dace.float64)) - test_dot("32-bit cuBLAS", np.float32, - make_sdfg("cuBLAS", dace.float32, dace.StorageType.GPU_Global)) - test_dot("64-bit cuBLAS", np.float64, - make_sdfg("cuBLAS", dace.float64, dace.StorageType.GPU_Global)) +############################################################################### +if __name__ == "__main__": + test_dot() ############################################################################### From 134e0ebf122c92cad3bb04ce2d5d9eb45b92b793 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sat, 14 Mar 2020 19:38:07 +0100 Subject: [PATCH 02/32] Zero memory for return values on the python side (to be safe) --- dace/codegen/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/codegen/compiler.py b/dace/codegen/compiler.py index b8630acfae..4fea100437 100644 --- a/dace/codegen/compiler.py +++ b/dace/codegen/compiler.py @@ -293,7 +293,7 @@ def _initialize_return_values(self, kwargs): self._return_arrays.append( np.ndarray([symbolic.evaluate(s, syms) for s in arr.shape], arr.dtype.type, - buffer=np.ndarray( + buffer=np.zeros( [symbolic.evaluate(arr.total_size, syms)], arr.dtype.type), strides=[ From a30b99c693a2cce6deb9c957c2989a71495cfadc Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 13:06:52 +0100 Subject: [PATCH 03/32] Fix add_mapped_tasklet API for disjoint input/output_nodes and connectors --- dace/sdfg.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/dace/sdfg.py b/dace/sdfg.py index e21166bff1..aab79a6449 100644 --- a/dace/sdfg.py +++ b/dace/sdfg.py @@ -3059,16 +3059,22 @@ def add_mapped_tasklet( self.add_nodes_from([map_entry, tasklet, map_exit]) # Create access nodes + inpdict = {} + outdict = {} if external_edges: + input_nodes = input_nodes or {} + output_nodes = output_nodes or {} input_data = set(memlet.data for memlet in inputs.values()) output_data = set(memlet.data for memlet in outputs.values()) - inpdict = input_nodes or {} - outdict = output_nodes or {} - if not input_nodes: - for inp in input_data: + for inp in input_data: + if inp in input_nodes: + inpdict[inp] = input_nodes[inp] + else: inpdict[inp] = self.add_read(inp) - if not output_nodes: - for out in output_data: + for out in output_data: + if out in output_nodes: + outdict[out] = output_nodes[out] + else: outdict[out] = self.add_write(out) # Connect inputs from map to tasklet From c3f1316557799a1889dc867fa3c67d2e0041694f Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 13:26:08 +0100 Subject: [PATCH 04/32] New API for adding a state before/after an existing state --- dace/sdfg.py | 38 +++++++++++++++++++++++++++ tests/add_state_api_test.py | 51 +++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 tests/add_state_api_test.py diff --git a/dace/sdfg.py b/dace/sdfg.py index aab79a6449..1bb5250ca1 100644 --- a/dace/sdfg.py +++ b/dace/sdfg.py @@ -1273,6 +1273,44 @@ def add_state(self, label=None, is_start_state=False): self.add_node(state, is_start_state=is_start_state) return state + def add_state_before(self, state: 'SDFGState', label=None, + is_start_state=False) -> 'SDFGState': + """ Adds a new SDFG state before an existing state, reconnecting + predecessors to it instead. + :param state: The state to prepend the new state before. + :param label: State label. + :param is_start_state: If True, resets SDFG starting state to this + state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.in_edges(state): + self.remove_edge(e) + self.add_edge(e.src, new_state, e.data) + # Add unconditional connection between the new state and the current + self.add_edge(new_state, state, ed.InterstateEdge()) + return new_state + + def add_state_after(self, state: 'SDFGState', label=None, + is_start_state=False) -> 'SDFGState': + """ Adds a new SDFG state after an existing state, reconnecting + it to the successors instead. + :param state: The state to append the new state after. + :param label: State label. + :param is_start_state: If True, resets SDFG starting state to this + state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.out_edges(state): + self.remove_edge(e) + self.add_edge(new_state, e.dst, e.data) + # Add unconditional connection between the current and the new state + self.add_edge(state, new_state, ed.InterstateEdge()) + return new_state + def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. """ index = 0 diff --git a/tests/add_state_api_test.py b/tests/add_state_api_test.py new file mode 100644 index 0000000000..a6a13fe16d --- /dev/null +++ b/tests/add_state_api_test.py @@ -0,0 +1,51 @@ +import dace +import numpy as np + + +@dace.program +def control_flow(A: dace.float64[10]): + if A[0] < 0.5: + for i in range(5): + A[i] *= 2 + # TODO: Disabled due to bug in control flow + # else: + # for i in range(5, 10): + # A[i] *= 2 + + +def _configure(): + A = np.random.rand(10) + expected = A.copy() + if A[0] < 0.5: + expected[0:5] *= 2 + # else: + # expected[5:10] *= 2 + sdfg = control_flow.to_sdfg() + return sdfg, A, expected + + +def test_state_before(): + sdfg, A, expected = _configure() + old_states = list(sdfg.nodes()) + for state in old_states: + sdfg.add_state_before(state) + + assert sdfg.number_of_nodes() == 2 * len(old_states) + sdfg(A=A) + assert np.allclose(A, expected) + + +def test_state_after(): + sdfg, A, expected = _configure() + old_states = list(sdfg.nodes()) + for state in old_states: + sdfg.add_state_after(state) + + assert sdfg.number_of_nodes() == 2 * len(old_states) + sdfg(A=A) + assert np.allclose(A, expected) + + +if __name__ == '__main__': + test_state_before() + test_state_after() From dba0f1cb1e914baff7eae6a599e27acd1cf969b5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 13:26:37 +0100 Subject: [PATCH 05/32] When inlining SDFGs, try to squeeze internal memlets first if dimensionality mismatches --- dace/transformation/helpers.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 73787a0242..92bd02a280 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -260,10 +260,13 @@ def unsqueeze_memlet(internal_memlet: Memlet, external_memlet: Memlet): result.subset.unsqueeze(to_unsqueeze) elif len(internal_memlet.subset) > len(external_memlet.subset): - raise ValueError('Unexpected extra dimensions in internal memlet ' - 'while un-squeezing memlet.\nExternal memlet: %s\n' - 'Internal memlet: %s' % - (external_memlet, internal_memlet)) + # Try to squeeze internal memlet + result.subset.squeeze() + if len(result.subset) != len(external_memlet.subset): + raise ValueError('Unexpected extra dimensions in internal memlet ' + 'while un-squeezing memlet.\nExternal memlet: %s\n' + 'Internal memlet: %s' % + (external_memlet, internal_memlet)) result.subset.offset(external_memlet.subset, False) From 00cb92c7561d9ae0e38ac63daaf547ab492cac46 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 13:31:53 +0100 Subject: [PATCH 06/32] Clean up dependencies vs. test dependencies --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index e431d0923c..508d023e75 100644 --- a/setup.py +++ b/setup.py @@ -57,9 +57,9 @@ }, include_package_data=True, install_requires=[ - 'numpy', 'networkx >= 2.2', 'astunparse', 'sympy', 'scipy', 'pyyaml', - 'absl-py', 'ply', 'websockets', 'graphviz', 'requests', 'flask', + 'numpy', 'networkx >= 2.2', 'astunparse', 'sympy', 'pyyaml', + 'ply', 'websockets', 'graphviz', 'requests', 'flask', 'scikit-build', 'cmake', 'aenum' ], - tests_require=['coverage'], + tests_require=['coverage', 'scipy', 'absl-py', 'opt_einsum'], scripts=['scripts/diode', 'scripts/dacelab', 'scripts/sdfv']) From e838e1af4960dc060548518ae6994e7169cc2ea1 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 13:32:45 +0100 Subject: [PATCH 07/32] Einsum function replacement with opt_einsum optimizer and GEMM specialization WIP: GEMM specialization is CUBLAS-only Co-authored-by: am-ivanov --- dace/frontend/common/__init__.py | 1 + dace/frontend/common/einsum.py | 519 +++++++++++++++++++++++++++++++ tests/numpy/einsum_test.py | 77 +++++ 3 files changed, 597 insertions(+) create mode 100644 dace/frontend/common/einsum.py create mode 100644 tests/numpy/einsum_test.py diff --git a/dace/frontend/common/__init__.py b/dace/frontend/common/__init__.py index 343f8cadd9..6fb598d258 100644 --- a/dace/frontend/common/__init__.py +++ b/dace/frontend/common/__init__.py @@ -3,3 +3,4 @@ from .op_impl import constant_array_multiplication from .op_impl import matrix_transpose, matrix_transpose_s from .op_impl import matrix_pointwise_op +from .einsum import create_einsum_sdfg diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py new file mode 100644 index 0000000000..9cb1a8fb69 --- /dev/null +++ b/dace/frontend/common/einsum.py @@ -0,0 +1,519 @@ +""" Classes to handle Einstein-notation sums (einsum) as a library node. """ +from functools import reduce +from itertools import chain +from string import ascii_letters +from typing import Dict, Optional + +from dace import dtypes, symbolic +from dace.graph.nodes import AccessNode +from dace.graph.edges import InterstateEdge +from dace.sdfg import SDFG, SDFGState +from dace.memlet import Memlet +from dace.frontend.common import op_repository as oprepo +from dace.frontend.common.op_impl import _to_blastype + + +def _is_sequential(index_list): + if not index_list: + return True + index_list = sorted(index_list) + smallest_elem = index_list[0] + return index_list == list( + range(smallest_elem, smallest_elem + len(index_list))) + + +class EinsumParser(object): + """ String parser for einsum. """ + def __init__(self, string): + inout = string.split('->') + if len(inout) == 1: + inputs, output = string, '' + else: + inputs, output = inout + + for char in chain(inputs, output): + if char not in ascii_letters + ',': + raise ValueError( + 'Invalid einsum string, subscript must contain' + ' letters, commas, and "->".') + + inputs = inputs.split(',') + + # No output given, assumed all "free" subscripts in inputs + if len(inout) == 1: + # Find intersection and union of all inputs for the non-outputs + # and free inputs + nonfree = set() + free = set() + for i, inp in enumerate(inputs): + for var in set(inp): + if (all(var not in set(s) for s in inputs[i + 1:]) + and var not in nonfree): + free.add(var) + else: + nonfree.add(var) + output = ''.join(sorted(free)) + + self.inputs = inputs + self.output = output + if len(inputs) != 2: + return + + # Special case: contracting two tensors + a, b = inputs + c = output + a_vars = set(a) + b_vars = set(b) + ab_vars = a_vars.union(b_vars) + c_vars = set(c) + if not ab_vars.issuperset(c_vars): + raise ValueError('Einsum subscript string includes outputs that do' + ' not appear as an input') + + batch_vars = a_vars.intersection(b_vars).intersection(c_vars) + sum_vars = a_vars.intersection(b_vars) - c_vars + a_only_vars = a_vars - sum_vars - batch_vars + b_only_vars = b_vars - sum_vars - batch_vars + + self.a_batch = [i for i, d in enumerate(a) if d in batch_vars] + self.a_sum = [i for i, d in enumerate(a) if d in sum_vars] + self.a_only = [i for i, d in enumerate(a) if d in a_only_vars] + + self.b_batch = [i for i, d in enumerate(b) if d in batch_vars] + self.b_sum = [i for i, d in enumerate(b) if d in sum_vars] + self.b_only = [i for i, d in enumerate(b) if d in b_only_vars] + + self.c_a_only = [i for i, d in enumerate(c) if d in a_only_vars] + self.c_b_only = [i for i, d in enumerate(c) if d in b_only_vars] + self.c_batch = [i for i, d in enumerate(c) if d in batch_vars] + + def is_bmm(self): + if len(self.inputs) != 2: + return False + for key, val in self.fields().items(): + if not _is_sequential(val): + return False + return True + + def fields(self): + return { + fname: fval + for fname, fval in self.__dict__.items() + if fname not in ('inputs', 'output') + } + + def __str__(self): + return str(self.__dict__) + + def __repr__(self): + return str(self) + + +# TODO: Remove once library nodes are used +cublas_initialized = False + + +def create_batch_gemm_sdfg(dtype, strides): + # TODO: Use MatMult library node + ######################### + sdfg = SDFG('einsum') + state = sdfg.add_state() + BATCH, M, K, N, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB = ( + symbolic.symbol(s) for s in [ + 'BATCH', 'M', 'K', 'N', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', + 'sCM', 'sCN', 'sCB' + ]) + + batched = strides['BATCH'] != 1 + + sdfg.add_array('X', + dtype=dtype, + shape=[BATCH, M, K] if batched else [M, K], + strides=[sAB, sAM, sAK] if batched else [sAM, sAK], + storage=dtypes.StorageType.GPU_Global) + sdfg.add_array('Y', + dtype=dtype, + shape=[BATCH, K, N] if batched else [K, N], + strides=[sBB, sBK, sBN] if batched else [sBK, sBN], + storage=dtypes.StorageType.GPU_Global) + sdfg.add_array('Z', + dtype=dtype, + shape=[BATCH, M, N] if batched else [M, N], + strides=[sCB, sCM, sCN] if batched else [sCM, sCN], + storage=dtypes.StorageType.GPU_Global) + + gX = state.add_read('X') + gY = state.add_read('Y') + gZ = state.add_access('Z') + + # possible order (C, row based) of dimensions in input array + # and computed result based on + # 1. N/T - transpose flag in cublas + # 2. LR/RL - order in which A and B are passed into cublas + # k m, n k -> n m (LR, N, N) + # m k, n k -> n m (LR, T, N) + # k m, k n -> n m (LR, N, T) + # m k, k n -> n m (LR, T, T) + # m k, k n -> m n (RL, N, N) + # m k, n k -> m n (RL, N, T) + # k m, k n -> m n (RL, T, N) + # k m, n k -> m n (RL, T, T) + # | | | + # use these 3 to detect correct option + + opts = { + 'mkm': { + 'swap': False, + 'lda': sAK, + 'ldb': sBN, + 'ldc': sCN, + 'ta': 'N', + 'tb': 'N' + }, + 'kkm': { + 'swap': False, + 'lda': sAM, + 'ldb': sBN, + 'ldc': sCN, + 'ta': 'T', + 'tb': 'N' + }, + 'mnm': { + 'swap': False, + 'lda': sAK, + 'ldb': sBK, + 'ldc': sCN, + 'ta': 'N', + 'tb': 'T' + }, + 'knm': { + 'swap': False, + 'lda': sAM, + 'ldb': sBK, + 'ldc': sCN, + 'ta': 'T', + 'tb': 'T' + }, + 'knn': { + 'swap': True, + 'lda': sAM, + 'ldb': sBK, + 'ldc': sCM, + 'ta': 'N', + 'tb': 'N' + }, + 'kkn': { + 'swap': True, + 'lda': sAM, + 'ldb': sBN, + 'ldc': sCM, + 'ta': 'N', + 'tb': 'T' + }, + 'mnn': { + 'swap': True, + 'lda': sAK, + 'ldb': sBK, + 'ldc': sCM, + 'ta': 'T', + 'tb': 'N' + }, + 'mkn': { + 'swap': True, + 'lda': sAK, + 'ldb': sBN, + 'ldc': sCM, + 'ta': 'T', + 'tb': 'T' + }, + } + + if strides['sAM'] == 1: + optA = 'm' + elif strides['sAK'] == 1: + optA = 'k' + else: + raise Exception("sAM or sAK should be 1") + + if strides['sBK'] == 1: + optB = 'k' + elif strides['sBN'] == 1: + optB = 'n' + else: + raise Exception("sBK or sBN should be 1") + + if strides['sCM'] == 1: + optC = 'm' + elif strides['sCN'] == 1: + optC = 'n' + else: + raise Exception("sCM or sCN should be 1") + + opt = opts[optA + optB + optC] + + opt['sta'] = sAB + opt['stb'] = sBB + opt['stc'] = sCB + opt['x'] = 'x' + opt['y'] = 'y' + opt['M'] = M + opt['N'] = N + if opt['swap']: + opt['lda'], opt['ldb'] = opt['ldb'], opt['lda'] + opt['sta'], opt['stb'] = opt['stb'], opt['sta'] + opt['x'], opt['y'] = opt['y'], opt['x'] + opt['ta'], opt['tb'] = opt['tb'], opt['ta'] + opt['M'], opt['N'] = opt['N'], opt['M'] + + global cublas_initialized + if not cublas_initialized: + code_global = ''' + #include + cublasHandle_t handle; + ''' + code_init = 'cublasCreate(&handle);' + code_exit = 'cublasDestroy(handle);' + cublas_initialized = True + else: + code_global = '' + code_init = '' + code_exit = '' + + cublas_gemm = 'cublas%sgemm' % _to_blastype(dtype.type) + + if not batched: + code = ''' + cublasSetStream(handle, __dace_current_stream); + {c_dtype} alpha_unused = 1.0, beta_unused = 0.0; + {cublas_gemm}(handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, + {M}, {N}, {K}, + &alpha_unused, + {x}, {lda}, + {y}, {ldb}, + &beta_unused, + z, {ldc}); + '''.format(BATCH=BATCH, + M=opt['M'], + N=opt['N'], + K=K, + lda=opt['lda'], + ldb=opt['ldb'], + ldc=opt['ldc'], + x=opt['x'], + y=opt['y'], + ta=opt['ta'], + tb=opt['tb'], + c_dtype=dtype.ctype, + cublas_gemm=cublas_gemm) + else: + code = ''' + cublasSetStream(handle, __dace_current_stream); + {c_dtype} alpha_unused = 1.0, beta_unused = 0.0; + {cublas_gemm}StridedBatched(handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, + {M}, {N}, {K}, + &alpha_unused, + {x}, {lda}, {stride_a}, + {y}, {ldb}, {stride_b}, + &beta_unused, + z, {ldc}, {stride_c}, + {BATCH}); + '''.format(BATCH=BATCH, + M=opt['M'], + N=opt['N'], + K=K, + lda=opt['lda'], + ldb=opt['ldb'], + ldc=opt['ldc'], + stride_a=opt['sta'], + stride_b=opt['stb'], + stride_c=opt['stc'], + x=opt['x'], + y=opt['y'], + ta=opt['ta'], + tb=opt['tb'], + c_dtype=dtype.ctype, + cublas_gemm=cublas_gemm) + + cublas_tasklet = state.add_tasklet(name="cublas_tasklet", + inputs={'x', 'y'}, + outputs={'z'}, + code=code, + code_global=code_global, + code_init=code_init, + code_exit=code_exit, + language=dtypes.Language.CPP) + + state.add_edge(gX, None, cublas_tasklet, 'x', + Memlet.from_array(gX, gX.desc(sdfg))) + state.add_edge(gY, None, cublas_tasklet, 'y', + Memlet.from_array(gY, gY.desc(sdfg))) + state.add_edge(cublas_tasklet, 'z', gZ, None, + Memlet.from_array(gZ, gZ.desc(sdfg))) + + return sdfg + + +def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + +@oprepo.replaces('numpy.einsum') +def create_einsum_sdfg(sdfg: SDFG, + state: SDFGState, + einsum_string: str, + *arrays: str, + dtype: Optional[dtypes.typeclass] = None, + optimize: bool = False, + output: Optional[str] = None): + return _create_einsum_internal(sdfg, state, einsum_string, *arrays, + dtype=dtype, optimize=optimize, + output=output)[0] + + +def _create_einsum_internal(sdfg: SDFG, + state: SDFGState, + einsum_string: str, + *arrays: str, + dtype: Optional[dtypes.typeclass] = None, + optimize: bool = False, + output: Optional[str] = None, + nodes: Optional[Dict[str, AccessNode]] = None): + # Infer shapes and strides of input/output arrays + einsum = EinsumParser(einsum_string) + + if len(einsum.inputs) != len(arrays): + raise ValueError('Invalid number of arrays for einsum expression') + + # Get shapes from arrays and verify dimensionality + chardict = {} + for inp, inpname in zip(einsum.inputs, arrays): + inparr = sdfg.arrays[inpname] + if len(inp) != len(inparr.shape): + raise ValueError('Dimensionality mismatch in input "%s"' % inpname) + for char, shp in zip(inp, inparr.shape): + if char in chardict and shp != chardict[char]: + raise ValueError('Dimension mismatch in einsum expression') + chardict[char] = shp + + if optimize: + # Try to import opt_einsum + try: + import opt_einsum as oe + except (ModuleNotFoundError, NameError, ImportError): + raise ImportError('To optimize einsum expressions, please install ' + 'the "opt_einsum" package.') + + for char, shp in chardict.items(): + if symbolic.issymbolic(shp): + raise ValueError('Einsum optimization cannot be performed ' + 'on symbolically-sized array dimension "%s" ' + 'for subscript character "%s"' % (shp, char)) + + # Create optimal contraction path + # noinspection PyTypeChecker + _, path_info = oe.contract_path( + einsum_string, *oe.helpers.build_views(einsum_string, chardict)) + + input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays} + result_node = None + + # Follow path and create a chain of operation SDFG states + for pair, nonfree, expr, after, blas in path_info.contraction_list: + result, result_node = _create_einsum_internal( + sdfg, state, expr, arrays[pair[0]], arrays[pair[1]], + dtype=dtype, optimize=False, output=None, nodes=input_nodes) + arrays = ([a for i, a in enumerate(arrays) if i not in pair] + + [result]) + input_nodes[result] = result_node + + return arrays[0], result_node + # END of einsum optimization + + input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays} + + # Get output shape from chardict, or [1] for a scalar output + output_shape = list(map(lambda k: chardict[k], einsum.output)) or [1] + output_index = ','.join(o for o in einsum.output) or '0' + + if output is None: + dtype = dtype or sdfg.arrays[arrays[0]].dtype + output, odesc = sdfg.add_temp_transient(output_shape, dtype) + to_init = True + else: + odesc = sdfg.arrays[output] + dtype = dtype or odesc.dtype + to_init = False + + if not einsum.is_bmm(): + # Fall back to "pure" SDFG einsum with conflict resolution + c = state.add_write(output) + + # Add state before this one to initialize the output value + if to_init: + init_state = sdfg.add_state_before(state) + init_state.add_mapped_tasklet( + 'einsum_reset', + {k: '0:%s' % chardict[k] for k in einsum.output}, + {}, 'out_%s = 0' % output, + {'out_%s' % output: Memlet.simple(output, output_index)}, + external_edges=True) + + # Pure einsum map + state.add_mapped_tasklet( + 'einsum', {k: '0:%s' % v + for k, v in chardict.items()}, + { + 'inp_%s' % arr: Memlet.simple(arr, ','.join(inp)) + for inp, arr in zip(einsum.inputs, arrays) + }, + 'out_%s = %s' % (output, ' * '.join('inp_%s' % arr for arr in arrays)), + {'out_%s' % output: Memlet.simple(output, output_index, + wcr_str='lambda a,b: a+b')}, + input_nodes=nodes, + output_nodes={output: c}, + external_edges=True) + else: + # TODO: Only CUDA is supported + # Represent einsum as a GEMM or batched GEMM + a_shape = sdfg.arrays[arrays[0]].shape + b_shape = sdfg.arrays[arrays[1]].shape + c_shape = output_shape + + a = input_nodes[arrays[0]] + b = input_nodes[arrays[1]] + c = state.add_write(output) + + # Compute GEMM dimensions and strides + strides = dict( + BATCH=prod([c_shape[dim] for dim in einsum.c_batch]), + M=prod([a_shape[dim] for dim in einsum.a_only]), + K=prod([a_shape[dim] for dim in einsum.a_sum]), + N=prod([b_shape[dim] for dim in einsum.b_only]), + sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1, + sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1, + sAB=prod(a_shape[einsum.a_batch[-1] + + 1:]) if einsum.a_batch else 1, + sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1, + sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1, + sBB=prod(b_shape[einsum.b_batch[-1] + + 1:]) if einsum.b_batch else 1, + sCM=prod(c_shape[einsum.c_a_only[-1] + + 1:]) if einsum.c_a_only else 1, + sCN=prod(c_shape[einsum.c_b_only[-1] + + 1:]) if einsum.c_b_only else 1, + sCB=prod(c_shape[einsum.c_batch[-1] + + 1:]) if einsum.c_batch else 1) + + # Create nested SDFG for GEMM + nsdfg = create_batch_gemm_sdfg(dtype, strides) + + nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'}, + strides) + state.add_edge(a, None, nsdfg_node, 'X', + Memlet.from_array(a.data, a.desc(sdfg))) + state.add_edge(b, None, nsdfg_node, 'Y', + Memlet.from_array(b.data, b.desc(sdfg))) + state.add_edge(nsdfg_node, 'Z', c, None, + Memlet.from_array(c.data, c.desc(sdfg))) + + return output, c diff --git a/tests/numpy/einsum_test.py b/tests/numpy/einsum_test.py new file mode 100644 index 0000000000..099b937003 --- /dev/null +++ b/tests/numpy/einsum_test.py @@ -0,0 +1,77 @@ +import dace +import numpy as np + +M = dace.symbol('M') +N = dace.symbol('N') + + +def test_general_einsum(): + @dace.program + def einsumtest(A: dace.float64[M, N], B: dace.float64[N, M], + C: dace.float64[M]): + return np.einsum('ij,ji,i->', A, B, C) + + A = np.random.rand(10, 20) + B = np.random.rand(20, 10) + C = np.random.rand(10) + out = einsumtest(A, B, C) + assert np.allclose(out, np.einsum('ij,ji,i->', A, B, C)) + + +def test_matmul(): + @dace.program + def einsumtest(A: dace.float64[M, N], B: dace.float64[N, M]): + return np.einsum('ik,kj', A, B) + + A = np.random.rand(10, 20) + B = np.random.rand(20, 10) + assert np.allclose(einsumtest(A, B), A @ B) + + +def test_batch_matmul(): + @dace.program + def einsumtest(A: dace.float64[4, M, N], B: dace.float64[4, N, M]): + return np.einsum('bik,bkj->bij', A, B) + + A = np.random.rand(4, 10, 20) + B = np.random.rand(4, 20, 10) + assert np.allclose(einsumtest(A, B), A @ B) + + +def test_opteinsum_sym(): + @dace.program + def einsumtest(A: dace.float64[N, N, N, N], B: dace.float64[N, N, N, N], + C: dace.float64[N, N, N, N], D: dace.float64[N, N, N, N], + E: dace.float64[N, N, N, N]): + return np.einsum('bdik,acaj,ikab,ajac,ikbd->', A, B, C, D, E, + optimize=True) + + A, B, C, D, E = tuple(np.random.rand(10, 10, 10, 10) for _ in range(5)) + try: + einsumtest(A, B, C, D, E) + raise AssertionError('Exception should have been raised') + except ValueError: + print('Exception successfully caught') + + +def test_opteinsum(): + N = 10 + + @dace.program + def einsumtest(A: dace.float64[N, N, N, N], B: dace.float64[N, N, N, N], + C: dace.float64[N, N, N, N], D: dace.float64[N, N, N, N], + E: dace.float64[N, N, N, N]): + return np.einsum('bdik,acaj,ikab,ajac,ikbd->', A, B, C, D, E, + optimize=True) + + A, B, C, D, E = tuple(np.random.rand(10, 10, 10, 10) for _ in range(5)) + + assert np.allclose(einsumtest(A, B, C, D, E), + np.einsum('bdik,acaj,ikab,ajac,ikbd->', A, B, C, D, E)) + + +if __name__ == '__main__': + test_general_einsum() + test_matmul() + test_batch_matmul() + test_opteinsum() From 3f9401083595a12a5f985998d2665abdc14ba53b Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 22:20:09 +0100 Subject: [PATCH 08/32] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 508d023e75..98266f0ba1 100644 --- a/setup.py +++ b/setup.py @@ -61,5 +61,5 @@ 'ply', 'websockets', 'graphviz', 'requests', 'flask', 'scikit-build', 'cmake', 'aenum' ], - tests_require=['coverage', 'scipy', 'absl-py', 'opt_einsum'], + extras_require={'testing': ['coverage', 'scipy', 'absl-py', 'opt_einsum']}, scripts=['scripts/diode', 'scripts/dacelab', 'scripts/sdfv']) From 20eac1ca6f8cde8539e293bf6393279c708956f1 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 22:20:47 +0100 Subject: [PATCH 09/32] Update Jenkinsfile --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 5d2033d671..788b226677 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -11,7 +11,7 @@ pipeline { echo "Installing additional dependencies" pip3 install --upgrade --user tensorflow-gpu==1.14.0 echo "Installing DaCe" - pip3 install --ignore-installed --upgrade --user ".[test]" . + pip3 install --ignore-installed --upgrade --user ".[testing]" . pip3 install --user cmake pip3 install --user coverage ''' From c7dfbb97ef39941765fd449a5408b9fa5dc5dde3 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 15 Mar 2020 22:21:00 +0100 Subject: [PATCH 10/32] Update .travis.yml --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 7ee6a4a5db..6058528796 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,7 @@ before_install: - sudo apt-get install libpapi-dev papi-tools install: - - pip install ".[test]" . + - pip install ".[testing]" . - pip install coverage codecov - if [ $DACE_optimizer_automatic_strict_transformations -eq 1 ]; then pip install tensorflow==1.15.0; fi From 184964ccbd456bda94be2bc836d67a6ed5956f99 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 16 Mar 2020 12:22:32 +0100 Subject: [PATCH 11/32] Refactor BLAS helpers into a file --- dace/frontend/common/einsum.py | 109 +------------------ dace/frontend/common/op_impl.py | 19 +--- dace/libraries/blas/blas_helpers.py | 142 +++++++++++++++++++++++++ dace/libraries/blas/nodes/matmul.py | 2 +- dace/libraries/blas/nodes/transpose.py | 1 - 5 files changed, 147 insertions(+), 126 deletions(-) create mode 100644 dace/libraries/blas/blas_helpers.py diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index 9cb1a8fb69..c840811331 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -10,7 +10,7 @@ from dace.sdfg import SDFG, SDFGState from dace.memlet import Memlet from dace.frontend.common import op_repository as oprepo -from dace.frontend.common.op_impl import _to_blastype +from dace.libraries.blas.blas_helpers import to_blastype, get_gemm_opts def _is_sequential(index_list): @@ -146,110 +146,7 @@ def create_batch_gemm_sdfg(dtype, strides): gY = state.add_read('Y') gZ = state.add_access('Z') - # possible order (C, row based) of dimensions in input array - # and computed result based on - # 1. N/T - transpose flag in cublas - # 2. LR/RL - order in which A and B are passed into cublas - # k m, n k -> n m (LR, N, N) - # m k, n k -> n m (LR, T, N) - # k m, k n -> n m (LR, N, T) - # m k, k n -> n m (LR, T, T) - # m k, k n -> m n (RL, N, N) - # m k, n k -> m n (RL, N, T) - # k m, k n -> m n (RL, T, N) - # k m, n k -> m n (RL, T, T) - # | | | - # use these 3 to detect correct option - - opts = { - 'mkm': { - 'swap': False, - 'lda': sAK, - 'ldb': sBN, - 'ldc': sCN, - 'ta': 'N', - 'tb': 'N' - }, - 'kkm': { - 'swap': False, - 'lda': sAM, - 'ldb': sBN, - 'ldc': sCN, - 'ta': 'T', - 'tb': 'N' - }, - 'mnm': { - 'swap': False, - 'lda': sAK, - 'ldb': sBK, - 'ldc': sCN, - 'ta': 'N', - 'tb': 'T' - }, - 'knm': { - 'swap': False, - 'lda': sAM, - 'ldb': sBK, - 'ldc': sCN, - 'ta': 'T', - 'tb': 'T' - }, - 'knn': { - 'swap': True, - 'lda': sAM, - 'ldb': sBK, - 'ldc': sCM, - 'ta': 'N', - 'tb': 'N' - }, - 'kkn': { - 'swap': True, - 'lda': sAM, - 'ldb': sBN, - 'ldc': sCM, - 'ta': 'N', - 'tb': 'T' - }, - 'mnn': { - 'swap': True, - 'lda': sAK, - 'ldb': sBK, - 'ldc': sCM, - 'ta': 'T', - 'tb': 'N' - }, - 'mkn': { - 'swap': True, - 'lda': sAK, - 'ldb': sBN, - 'ldc': sCM, - 'ta': 'T', - 'tb': 'T' - }, - } - - if strides['sAM'] == 1: - optA = 'm' - elif strides['sAK'] == 1: - optA = 'k' - else: - raise Exception("sAM or sAK should be 1") - - if strides['sBK'] == 1: - optB = 'k' - elif strides['sBN'] == 1: - optB = 'n' - else: - raise Exception("sBK or sBN should be 1") - - if strides['sCM'] == 1: - optC = 'm' - elif strides['sCN'] == 1: - optC = 'n' - else: - raise Exception("sCM or sCN should be 1") - - opt = opts[optA + optB + optC] + opt = get_gemm_opts(xarr, yarr, zarr) opt['sta'] = sAB opt['stb'] = sBB @@ -279,7 +176,7 @@ def create_batch_gemm_sdfg(dtype, strides): code_init = '' code_exit = '' - cublas_gemm = 'cublas%sgemm' % _to_blastype(dtype.type) + cublas_gemm = 'cublas%sgemm' % to_blastype(dtype.type) if not batched: code = ''' diff --git a/dace/frontend/common/op_impl.py b/dace/frontend/common/op_impl.py index fccf4f6b41..c2fe30bdad 100644 --- a/dace/frontend/common/op_impl.py +++ b/dace/frontend/common/op_impl.py @@ -7,6 +7,7 @@ from dace import symbolic import typing import numpy as np +from dace.libraries.blas.blas_helpers import to_blastype as _to_blastype State = dace.sdfg.SDFGState Shape = typing.List[typing.Union[int, symbolic.symbol]] @@ -16,24 +17,6 @@ # TODO: Most of the external operations here emit Z (complex double) ops, fix -def _to_blastype(dtype): - """ Returns a BLAS character that corresponds to the input type. - Used in MKL/CUBLAS calls. """ - - if dtype == np.float16: - return 'H' - elif dtype == np.float32: - return 'S' - elif dtype == np.float64: - return 'D' - elif dtype == np.complex64: - return 'C' - elif dtype == np.complex128: - return 'Z' - else: - raise TypeError('Type %s not supported in BLAS operations' % - dtype.__name__) - def _to_cudatype(dtype): """ Returns a CUDA typename that corresponds to the input type. diff --git a/dace/libraries/blas/blas_helpers.py b/dace/libraries/blas/blas_helpers.py new file mode 100644 index 0000000000..ad749dc2a0 --- /dev/null +++ b/dace/libraries/blas/blas_helpers.py @@ -0,0 +1,142 @@ +import numpy as np +from dace.data import Array +from typing import Any, Dict + +def to_blastype(dtype): + """ Returns a BLAS character that corresponds to the input type. + Used in MKL/CUBLAS calls. """ + + if dtype == np.float16: + return 'H' + elif dtype == np.float32: + return 'S' + elif dtype == np.float64: + return 'D' + elif dtype == np.complex64: + return 'C' + elif dtype == np.complex128: + return 'Z' + else: + raise TypeError('Type %s not supported in BLAS operations' % + dtype.__name__) + + +def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]: + """ + Returns GEMM argument order, transposition, and leading dimensions + based on column-major storage from dace arrays. + :param a: Data descriptor for the first matrix. + :param b: Data descriptor for the second matrix. + :param c: Data descriptor for the output matrix. + :return: A dictionary with the following keys: swap (if True, a and b + should be swapped); lda, ldb, ldc (leading dimensions); ta, tb + (whether GEMM should be called with OP_N or OP_T). + """ + # possible order (C, row based) of dimensions in input array + # and computed result based on + # 1. N/T - transpose flag in cublas + # 2. LR/RL - order in which A and B are passed into cublas + # k m, n k -> n m (LR, N, N) + # m k, n k -> n m (LR, T, N) + # k m, k n -> n m (LR, N, T) + # m k, k n -> n m (LR, T, T) + # m k, k n -> m n (RL, N, N) + # m k, n k -> m n (RL, N, T) + # k m, k n -> m n (RL, T, N) + # k m, n k -> m n (RL, T, T) + # | | | + # use these 3 to detect correct option + + sAM, sAK = a.strides[-2:] + sBK, sBN = b.strides[-2:] + sCM, sCN = c.strides[-2:] + + opts = { + 'mkm': { + 'swap': False, + 'lda': sAK, + 'ldb': sBN, + 'ldc': sCN, + 'ta': 'N', + 'tb': 'N' + }, + 'kkm': { + 'swap': False, + 'lda': sAM, + 'ldb': sBN, + 'ldc': sCN, + 'ta': 'T', + 'tb': 'N' + }, + 'mnm': { + 'swap': False, + 'lda': sAK, + 'ldb': sBK, + 'ldc': sCN, + 'ta': 'N', + 'tb': 'T' + }, + 'knm': { + 'swap': False, + 'lda': sAM, + 'ldb': sBK, + 'ldc': sCN, + 'ta': 'T', + 'tb': 'T' + }, + 'knn': { + 'swap': True, + 'lda': sAM, + 'ldb': sBK, + 'ldc': sCM, + 'ta': 'N', + 'tb': 'N' + }, + 'kkn': { + 'swap': True, + 'lda': sAM, + 'ldb': sBN, + 'ldc': sCM, + 'ta': 'N', + 'tb': 'T' + }, + 'mnn': { + 'swap': True, + 'lda': sAK, + 'ldb': sBK, + 'ldc': sCM, + 'ta': 'T', + 'tb': 'N' + }, + 'mkn': { + 'swap': True, + 'lda': sAK, + 'ldb': sBN, + 'ldc': sCM, + 'ta': 'T', + 'tb': 'T' + }, + } + + if sAM == 1: + optA = 'm' + elif sAK == 1: + optA = 'k' + else: + raise Exception("sAM or sAK should be 1") + + if sBK == 1: + optB = 'k' + elif sBN == 1: + optB = 'n' + else: + raise Exception("sBK or sBN should be 1") + + if sCM == 1: + optC = 'm' + elif sCN == 1: + optC = 'n' + else: + raise Exception("sCM or sCN should be 1") + + return opts[optA + optB + optC] diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index 6bca4c3866..e0f7e93ef8 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -1,11 +1,11 @@ from copy import deepcopy as dc import numpy as np from dace.config import Config -from dace.frontend.common.op_impl import gpu_transform_tasklet import dace.library import dace.properties import dace.graph.nodes from dace.transformation.pattern_matching import ExpandTransformation +from dace.libraries.blas.blas_helpers import to_blastype, get_gemm_opts from .. import environments diff --git a/dace/libraries/blas/nodes/transpose.py b/dace/libraries/blas/nodes/transpose.py index 40aedabbb0..cbde5b1672 100644 --- a/dace/libraries/blas/nodes/transpose.py +++ b/dace/libraries/blas/nodes/transpose.py @@ -1,7 +1,6 @@ import functools from copy import deepcopy as dc from dace.config import Config -from dace.frontend.common.op_impl import gpu_transform_tasklet import dace.library import dace.properties import dace.graph.nodes From 5f9fe720e308f2943e574a203dfc99dc027fcba6 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 16 Mar 2020 12:22:56 +0100 Subject: [PATCH 12/32] Fix minor issues in einsum --- dace/frontend/common/einsum.py | 58 ++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 23 deletions(-) diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index c840811331..d2c88982cd 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -119,28 +119,32 @@ def create_batch_gemm_sdfg(dtype, strides): sdfg = SDFG('einsum') state = sdfg.add_state() BATCH, M, K, N, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB = ( - symbolic.symbol(s) for s in [ + #symbolic.symbol(s) for s in [ + strides[s] for s in [ 'BATCH', 'M', 'K', 'N', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM', 'sCN', 'sCB' ]) batched = strides['BATCH'] != 1 - sdfg.add_array('X', - dtype=dtype, - shape=[BATCH, M, K] if batched else [M, K], - strides=[sAB, sAM, sAK] if batched else [sAM, sAK], - storage=dtypes.StorageType.GPU_Global) - sdfg.add_array('Y', - dtype=dtype, - shape=[BATCH, K, N] if batched else [K, N], - strides=[sBB, sBK, sBN] if batched else [sBK, sBN], - storage=dtypes.StorageType.GPU_Global) - sdfg.add_array('Z', - dtype=dtype, - shape=[BATCH, M, N] if batched else [M, N], - strides=[sCB, sCM, sCN] if batched else [sCM, sCN], - storage=dtypes.StorageType.GPU_Global) + _, xarr = sdfg.add_array('X', + dtype=dtype, + shape=[BATCH, M, K] if batched else [M, K], + strides=[sAB, sAM, sAK] if batched else [ + sAM, sAK], + storage=dtypes.StorageType.GPU_Global) + _, yarr = sdfg.add_array('Y', + dtype=dtype, + shape=[BATCH, K, N] if batched else [K, N], + strides=[sBB, sBK, sBN] if batched else [ + sBK, sBN], + storage=dtypes.StorageType.GPU_Global) + _, zarr = sdfg.add_array('Z', + dtype=dtype, + shape=[BATCH, M, N] if batched else [M, N], + strides=[sCB, sCM, sCN] if batched else [ + sCM, sCN], + storage=dtypes.StorageType.GPU_Global) gX = state.add_read('X') gY = state.add_read('Y') @@ -348,12 +352,20 @@ def _create_einsum_internal(sdfg: SDFG, # Add state before this one to initialize the output value if to_init: init_state = sdfg.add_state_before(state) - init_state.add_mapped_tasklet( - 'einsum_reset', - {k: '0:%s' % chardict[k] for k in einsum.output}, - {}, 'out_%s = 0' % output, - {'out_%s' % output: Memlet.simple(output, output_index)}, - external_edges=True) + if len(einsum.output) > 0: + init_state.add_mapped_tasklet( + 'einsum_reset', + {k: '0:%s' % chardict[k] for k in einsum.output}, + {}, 'out_%s = 0' % output, + {'out_%s' % output: Memlet.simple(output, output_index)}, + external_edges=True) + else: # Scalar output + t = init_state.add_tasklet('einsum_reset', {}, + {'out_%s' % output}, + 'out_%s = 0' % output) + onode = init_state.add_write(output) + init_state.add_edge(t, 'out_%s' % output, onode, None, + Memlet.simple(output, '0')) # Pure einsum map state.add_mapped_tasklet( @@ -366,7 +378,7 @@ def _create_einsum_internal(sdfg: SDFG, 'out_%s = %s' % (output, ' * '.join('inp_%s' % arr for arr in arrays)), {'out_%s' % output: Memlet.simple(output, output_index, wcr_str='lambda a,b: a+b')}, - input_nodes=nodes, + input_nodes=input_nodes, output_nodes={output: c}, external_edges=True) else: From 37eae5f50f6cad5465ad5e389b93f84410893b90 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 16 Mar 2020 12:23:38 +0100 Subject: [PATCH 13/32] CUBLAS matrix multiplication + test --- dace/libraries/blas/include/dace_cublas.h | 42 ++----- dace/libraries/blas/nodes/matmul.py | 132 +++++++++++++++++++- tests/library/matmul_cudatest.py | 139 ++++++++++++++++++++++ 3 files changed, 278 insertions(+), 35 deletions(-) create mode 100644 tests/library/matmul_cudatest.py diff --git a/dace/libraries/blas/include/dace_cublas.h b/dace/libraries/blas/include/dace_cublas.h index e6429a7864..58968aec61 100644 --- a/dace/libraries/blas/include/dace_cublas.h +++ b/dace/libraries/blas/include/dace_cublas.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include // size_t @@ -66,10 +67,10 @@ namespace { class _CublasConstants { public: - float const* FloatZero() const { return float_zero_; } - double const* DoubleZero() const { return double_zero_; } - cuComplex const* Complex64Zero() const { return complex64_zero_; } - cuDoubleComplex const* Complex128Zero() const { return complex128_zero_; } + float const* FloatZero() const { return (float*)zero_; } + double const* DoubleZero() const { return (double*)zero_; } + cuComplex const* Complex64Zero() const { return (cuComplex*)zero_; } + cuDoubleComplex const* Complex128Zero() const { return (cuDoubleComplex*)zero_; } float const* FloatPone() const { return float_pone_; } double const* DoublePone() const { return double_pone_; } cuComplex const* Complex64Pone() const { return complex64_pone_; } @@ -79,30 +80,17 @@ class _CublasConstants { if (cudaSetDevice(device) != cudaSuccess) { throw std::runtime_error("Failed to set CUDA device."); } - // Allocate constant zero - cudaMalloc(&float_zero_, sizeof(float) * 1); - float float_zero = 0.0f; - cudaMemcpy(float_zero_, &float_zero, sizeof(float) * 1, - cudaMemcpyHostToDevice); - cudaMalloc(&double_zero_, sizeof(double) * 1); - double double_zero = 0.0; - cudaMemcpy(double_zero_, &double_zero, sizeof(double) * 1, - cudaMemcpyHostToDevice); - cudaMalloc(&complex64_zero_, sizeof(cuComplex) * 1); - cuComplex complex64_zero = make_cuComplex(0.0f, 0.0f); - cudaMemcpy(complex64_zero_, &complex64_zero, sizeof(cuComplex) * 1, - cudaMemcpyHostToDevice); - cudaMalloc(&complex128_zero_, sizeof(cuDoubleComplex) * 1); - cuDoubleComplex complex128_zero = make_cuDoubleComplex(0.0, 0.0); - cudaMemcpy(complex128_zero_, &complex128_zero, sizeof(cuDoubleComplex) * 1, - cudaMemcpyHostToDevice); + // Allocate constant zero with the largest used size + cudaMalloc(&zero_, sizeof(cuDoubleComplex) * 1); + cudaMemset(zero_, 0, sizeof(cuDoubleComplex) * 1); + // Allocate constant one cudaMalloc(&float_pone_, sizeof(float) * 1); float float_pone = 1.0f; cudaMemcpy(float_pone_, &float_pone, sizeof(float) * 1, cudaMemcpyHostToDevice); cudaMalloc(&double_pone_, sizeof(double) * 1); - double double_pone = 0.0; + double double_pone = 1.0; cudaMemcpy(double_pone_, &double_pone, sizeof(double) * 1, cudaMemcpyHostToDevice); cudaMalloc(&complex64_pone_, sizeof(cuComplex) * 1); @@ -118,10 +106,7 @@ class _CublasConstants { _CublasConstants(_CublasConstants const&) = delete; ~_CublasConstants() { - cudaFree(float_zero_); - cudaFree(double_zero_); - cudaFree(complex64_zero_); - cudaFree(complex128_zero_); + cudaFree(zero_); cudaFree(float_pone_); cudaFree(double_pone_); cudaFree(complex64_pone_); @@ -137,10 +122,7 @@ class _CublasConstants { } } - float* float_zero_; - double* double_zero_; - cuComplex* complex64_zero_; - cuDoubleComplex* complex128_zero_; + void* zero_; float* float_pone_; double* double_pone_; cuComplex* complex64_pone_; diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index e0f7e93ef8..fbccdb6b60 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -105,20 +105,17 @@ class ExpandMatMulMKL(ExpandTransformation): def expansion(node, state, sdfg): node.validate(sdfg, state) dtype = node.dtype + func = to_blastype(dtype.type).lower() + 'gemm' if dtype == dace.float32: - func = "sgemm" alpha = "1.0f" beta = "0.0f" elif dtype == dace.float64: - func = "dgemm" alpha = "1.0" beta = "0.0" elif dtype == dace.complex64: - func = "cgemm" alpha = "dace::blas::BlasConstants::Get().Complex64Pone()" beta = "dace::blas::BlasConstants::Get().Complex64Zero()" elif dtype == dace.complex128: - func = "zgemm" alpha = "dace::blas::BlasConstants::Get().Complex128Pone()" beta = "dace::blas::BlasConstants::Get().Complex128Zero()" else: @@ -136,11 +133,136 @@ def expansion(node, state, sdfg): return tasklet +@dace.library.expansion +class ExpandMatMulMKL(ExpandTransformation): + + environments = [environments.intel_mkl.IntelMKL] + + @staticmethod + def expansion(node, state, sdfg): + node.validate(sdfg, state) + dtype = node.dtype + func = to_blastype(dtype.type).lower() + 'gemm' + if dtype == dace.float32: + alpha = "1.0f" + beta = "0.0f" + elif dtype == dace.float64: + alpha = "1.0" + beta = "0.0" + elif dtype == dace.complex64: + alpha = "dace::blas::BlasConstants::Get().Complex64Pone()" + beta = "dace::blas::BlasConstants::Get().Complex64Zero()" + elif dtype == dace.complex128: + alpha = "dace::blas::BlasConstants::Get().Complex128Pone()" + beta = "dace::blas::BlasConstants::Get().Complex128Zero()" + else: + raise ValueError("Unsupported type for BLAS dot product: " + + str(dtype)) + (_, _, (m, k)), (_, _, (_, n)) = _get_matmul_inputs(node, state, sdfg) + code = ("cblas_{f}(CblasRowMajor, CblasNoTrans, CblasNoTrans, " + "{m}, {n}, {k}, {a}, _a, {k}, _b, {n}, {b}, _c, {n});").format( + f=func, m=m, n=n, k=k, a=alpha, b=beta) + tasklet = dace.graph.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP) + return tasklet + + +@dace.library.expansion +class ExpandMatMulCuBLAS(ExpandTransformation): + + environments = [environments.cublas.cuBLAS] + + @staticmethod + def expansion(node, state, sdfg): + gpuid = node.location or '0' + node.validate(sdfg, state) + dtype = node.dtype + func = '%sgemm' % to_blastype(dtype.type) + if dtype == dace.float32: + factort = 'Float' + elif dtype == dace.float64: + factort = 'Double' + elif dtype == dace.complex64: + factort = 'Complex64' + elif dtype == dace.complex128: + factort = 'Complex128' + else: + raise ValueError("Unsupported type: " + str(dtype)) + + alpha = "dace::blas::CublasConstants::Get(__dace_cuda_device).%sPone()" % factort + beta = "dace::blas::CublasConstants::Get(__dace_cuda_device).%sZero()" % factort + + # Find inputs and output + adesc, bdesc, cdesc = None, None, None + for e in state.in_edges(node): + if e.dst_conn == '_a': + anode = state.memlet_path(e)[0].src + if isinstance(anode, dace.graph.nodes.AccessNode): + adesc = sdfg.arrays[anode.data] + elif e.dst_conn == '_b': + bnode = state.memlet_path(e)[0].src + if isinstance(bnode, dace.graph.nodes.AccessNode): + bdesc = sdfg.arrays[bnode.data] + for e in state.out_edges(node): + if e.src_conn == '_c': + cnode = state.memlet_path(e)[-1].dst + if isinstance(cnode, dace.graph.nodes.AccessNode): + cdesc = sdfg.arrays[cnode.data] + if not adesc or not bdesc or not cdesc: + raise ValueError('Unsupported input/output arrays') + + (_, _, (m, k)), (_, _, (_, n)) = _get_matmul_inputs(node, state, sdfg) + opt = get_gemm_opts(adesc, bdesc, cdesc) + opt['x'] = '_a' + opt['y'] = '_b' + opt['M'] = m + opt['N'] = n + if opt['swap']: + opt['lda'], opt['ldb'] = opt['ldb'], opt['lda'] + opt['x'], opt['y'] = opt['y'], opt['x'] + opt['ta'], opt['tb'] = opt['tb'], opt['ta'] + opt['M'], opt['N'] = opt['N'], opt['M'] + + + code = environments.cublas.cuBLAS.handle_setup_code(node) + ''' + cublas{func}(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, + {M}, {N}, {K}, + {alpha}, + {x}, {lda}, + {y}, {ldb}, + {beta}, + _c, {ldc}); + '''.format(M=opt['M'], + N=opt['N'], + K=k, + lda=opt['lda'], + ldb=opt['ldb'], + ldc=opt['ldc'], + x=opt['x'], + y=opt['y'], + ta=opt['ta'], + tb=opt['tb'], + alpha=alpha, + beta=beta, + c_dtype=dtype.ctype, + func=func) + tasklet = dace.graph.nodes.Tasklet(node.name, + node.in_connectors, + node.out_connectors, + code, + language=dace.dtypes.Language.CPP) + return tasklet + + @dace.library.node class MatMul(dace.graph.nodes.LibraryNode): # Global properties - implementations = {"pure": ExpandMatMulPure, "MKL": ExpandMatMulMKL} + implementations = {"pure": ExpandMatMulPure, "MKL": ExpandMatMulMKL, + "cuBLAS": ExpandMatMulCuBLAS} default_implementation = None # Object fields diff --git a/tests/library/matmul_cudatest.py b/tests/library/matmul_cudatest.py new file mode 100644 index 0000000000..a3efdfd7af --- /dev/null +++ b/tests/library/matmul_cudatest.py @@ -0,0 +1,139 @@ +import dace +from dace.memlet import Memlet +from dace.codegen.compiler import CompilerConfigurationError, CompilationError +import dace.libraries.blas as blas +import numpy as np +import sys +import warnings + +############################################################################### + + +def make_sdfg(implementation, dtype, storage=dace.StorageType.Default): + m = dace.symbol("m") + n = dace.symbol("n") + k = dace.symbol("k") + + suffix = "_device" if storage != dace.StorageType.Default else "" + transient = storage != dace.StorageType.Default + + sdfg = dace.SDFG("cublasgemm_{}".format(dtype)) + state = sdfg.add_state("dataflow") + + sdfg.add_array("x" + suffix, [m, k], + dtype, + storage=storage, + transient=transient) + sdfg.add_array("y" + suffix, [k, n], + dtype, + storage=storage, + transient=transient) + sdfg.add_array("result" + suffix, [m, n], + dtype, + storage=storage, + transient=transient) + + x = state.add_read("x" + suffix) + y = state.add_read("y" + suffix) + result = state.add_write("result" + suffix) + + node = blas.nodes.matmul.MatMul("matmul", dtype) + node.implementation = implementation + + state.add_memlet_path(x, + node, + dst_conn="_a", + memlet=Memlet.simple(x, "0:m, 0:k")) + state.add_memlet_path(y, + node, + dst_conn="_b", + memlet=Memlet.simple(y, "0:k, 0:n")) + # TODO: remove -1 once this no longer triggers a write in the codegen. + state.add_memlet_path(node, + result, + src_conn="_c", + memlet=Memlet.simple(result, "0:m, 0:n")) + + if storage != dace.StorageType.Default: + sdfg.add_array("x", [m, k], dtype) + sdfg.add_array("y", [k, n], dtype) + sdfg.add_array("result", [m, n], dtype) + + init_state = sdfg.add_state("copy_to_device") + sdfg.add_edge(init_state, state, dace.InterstateEdge()) + + x_host = init_state.add_read("x") + y_host = init_state.add_read("y") + x_device = init_state.add_write("x" + suffix) + y_device = init_state.add_write("y" + suffix) + init_state.add_memlet_path(x_host, + x_device, + memlet=Memlet.simple(x_host, + "0:m, 0:k")) + init_state.add_memlet_path(y_host, + y_device, + memlet=Memlet.simple(y_host, + "0:k, 0:n")) + + finalize_state = sdfg.add_state("copy_to_host") + sdfg.add_edge(state, finalize_state, dace.InterstateEdge()) + + result_device = finalize_state.add_write("result" + suffix) + result_host = finalize_state.add_read("result") + finalize_state.add_memlet_path(result_device, + result_host, + memlet=Memlet.simple(result_device, + "0:m, 0:n")) + + return sdfg + + +############################################################################### + + +def _test_matmul(implementation, dtype, sdfg): + try: + csdfg = sdfg.compile() + except (CompilerConfigurationError, CompilationError): + warnings.warn( + 'Configuration/compilation failed, library missing or ' + 'misconfigured, skipping test for {}.'.format(implementation)) + return + + m, n, k = 32, 31, 30 + + x = np.ndarray([m, k], dtype=dtype) + y = np.ndarray([k, n], dtype=dtype) + result = np.ndarray([m, n], dtype=dtype) + + x[:] = 2.5 + y[:] = 2 + result[:] = 0 + + csdfg(x=x, y=y, result=result, m=m, n=n, k=k) + + ref = np.dot(x, y) + + diff = np.linalg.norm(ref - result) + if diff >= 1e-6: + print("Unexpected result returned from dot product: " + "diff %f" % diff) + sys.exit(1) + + print("Test ran successfully for {}.".format(implementation)) + + +def test_matmul(): + _test_matmul( + "32-bit cuBLAS", np.float32, + make_sdfg("cuBLAS", dace.float32, dace.StorageType.GPU_Global)) + _test_matmul( + "64-bit cuBLAS", np.float64, + make_sdfg("cuBLAS", dace.float64, dace.StorageType.GPU_Global)) + + +############################################################################### + +if __name__ == "__main__": + test_matmul() +############################################################################### From 661560ed4d0b7ed67b6a842ddee9d74eec7822dd Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 16 Mar 2020 12:29:52 +0100 Subject: [PATCH 14/32] Added matmul test to CI --- tests/cuda_test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/cuda_test.sh b/tests/cuda_test.sh index 7e7970418d..5e2897da5a 100755 --- a/tests/cuda_test.sh +++ b/tests/cuda_test.sh @@ -184,6 +184,7 @@ runall() { runopt samples/simple/axpy.py $1 'GPUTransformSDFG$0' runtestargs instrumentation_test.py gpu + runtestargs library/matmul_cudatest.py } From f1266f0ef0f8d3cfc50c789a14891ae8262e6cda Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 16 Mar 2020 14:26:56 +0100 Subject: [PATCH 15/32] matmul tests are more exhaustive (data layout, types) --- dace/libraries/blas/include/dace_cublas.h | 8 +++ dace/libraries/blas/nodes/matmul.py | 17 +++-- tests/library/matmul_cudatest.py | 87 ++++++++++++++--------- 3 files changed, 75 insertions(+), 37 deletions(-) diff --git a/dace/libraries/blas/include/dace_cublas.h b/dace/libraries/blas/include/dace_cublas.h index 58968aec61..ef67c9abb9 100644 --- a/dace/libraries/blas/include/dace_cublas.h +++ b/dace/libraries/blas/include/dace_cublas.h @@ -67,10 +67,12 @@ namespace { class _CublasConstants { public: + __half const* HalfZero() const { return (__half*)zero_; } float const* FloatZero() const { return (float*)zero_; } double const* DoubleZero() const { return (double*)zero_; } cuComplex const* Complex64Zero() const { return (cuComplex*)zero_; } cuDoubleComplex const* Complex128Zero() const { return (cuDoubleComplex*)zero_; } + __half const* HalfPone() const { return half_pone_; } float const* FloatPone() const { return float_pone_; } double const* DoublePone() const { return double_pone_; } cuComplex const* Complex64Pone() const { return complex64_pone_; } @@ -85,6 +87,10 @@ class _CublasConstants { cudaMemset(zero_, 0, sizeof(cuDoubleComplex) * 1); // Allocate constant one + cudaMalloc(&half_pone_, sizeof(__half) * 1); + __half half_pone = __float2half(1.0f); + cudaMemcpy(half_pone_, &half_pone, sizeof(__half) * 1, + cudaMemcpyHostToDevice); cudaMalloc(&float_pone_, sizeof(float) * 1); float float_pone = 1.0f; cudaMemcpy(float_pone_, &float_pone, sizeof(float) * 1, @@ -107,6 +113,7 @@ class _CublasConstants { ~_CublasConstants() { cudaFree(zero_); + cudaFree(half_pone_); cudaFree(float_pone_); cudaFree(double_pone_); cudaFree(complex64_pone_); @@ -123,6 +130,7 @@ class _CublasConstants { } void* zero_; + __half* half_pone_; float* float_pone_; double* double_pone_; cuComplex* complex64_pone_; diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index fbccdb6b60..c1b32d12c2 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -181,13 +181,20 @@ def expansion(node, state, sdfg): node.validate(sdfg, state) dtype = node.dtype func = '%sgemm' % to_blastype(dtype.type) - if dtype == dace.float32: + if dtype == dace.float16: + cdtype = '__half' + factort = 'Half' + elif dtype == dace.float32: + cdtype = 'float' factort = 'Float' elif dtype == dace.float64: + cdtype = 'double' factort = 'Double' elif dtype == dace.complex64: + cdtype = 'cuComplex' factort = 'Complex64' elif dtype == dace.complex128: + cdtype = 'cuDoubleComplex' factort = 'Complex128' else: raise ValueError("Unsupported type: " + str(dtype)) @@ -231,10 +238,10 @@ def expansion(node, state, sdfg): cublas{func}(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, {M}, {N}, {K}, {alpha}, - {x}, {lda}, - {y}, {ldb}, + ({dtype}*){x}, {lda}, + ({dtype}*){y}, {ldb}, {beta}, - _c, {ldc}); + ({dtype}*)_c, {ldc}); '''.format(M=opt['M'], N=opt['N'], K=k, @@ -247,7 +254,7 @@ def expansion(node, state, sdfg): tb=opt['tb'], alpha=alpha, beta=beta, - c_dtype=dtype.ctype, + dtype=cdtype, func=func) tasklet = dace.graph.nodes.Tasklet(node.name, node.in_connectors, diff --git a/tests/library/matmul_cudatest.py b/tests/library/matmul_cudatest.py index a3efdfd7af..a883b38ead 100644 --- a/tests/library/matmul_cudatest.py +++ b/tests/library/matmul_cudatest.py @@ -2,6 +2,7 @@ from dace.memlet import Memlet from dace.codegen.compiler import CompilerConfigurationError, CompilationError import dace.libraries.blas as blas +import itertools import numpy as np import sys import warnings @@ -9,7 +10,8 @@ ############################################################################### -def make_sdfg(implementation, dtype, storage=dace.StorageType.Default): +def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, + data_layout='CCC'): m = dace.symbol("m") n = dace.symbol("n") k = dace.symbol("k") @@ -17,21 +19,31 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default): suffix = "_device" if storage != dace.StorageType.Default else "" transient = storage != dace.StorageType.Default - sdfg = dace.SDFG("cublasgemm_{}".format(dtype)) + sdfg = dace.SDFG("cublasgemm_{}_{}".format(dtype.type.__name__, data_layout)) state = sdfg.add_state("dataflow") + # Data layout is a 3-character string with either C (for row major) + # or F (for column major) matrices for x, y, and z respectively. + xstrides = (k, 1) if data_layout[0] == 'C' else (1, m) + ystrides = (n, 1) if data_layout[1] == 'C' else (1, k) + zstrides = (n, 1) if data_layout[2] == 'C' else (1, m) + + sdfg.add_array("x" + suffix, [m, k], dtype, storage=storage, - transient=transient) + transient=transient, + strides=xstrides) sdfg.add_array("y" + suffix, [k, n], dtype, storage=storage, - transient=transient) + transient=transient, + strides=ystrides) sdfg.add_array("result" + suffix, [m, n], dtype, storage=storage, - transient=transient) + transient=transient, + strides=zstrides) x = state.add_read("x" + suffix) y = state.add_read("y" + suffix) @@ -48,7 +60,6 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default): node, dst_conn="_b", memlet=Memlet.simple(y, "0:k, 0:n")) - # TODO: remove -1 once this no longer triggers a write in the codegen. state.add_memlet_path(node, result, src_conn="_c", @@ -91,31 +102,27 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default): ############################################################################### -def _test_matmul(implementation, dtype, sdfg): - try: - csdfg = sdfg.compile() - except (CompilerConfigurationError, CompilationError): - warnings.warn( - 'Configuration/compilation failed, library missing or ' - 'misconfigured, skipping test for {}.'.format(implementation)) - return +def _test_matmul(implementation, dtype, impl_name, storage, + data_layout='CCC', eps=1e-6): + sdfg = make_sdfg(impl_name, dtype, storage, data_layout) + csdfg = sdfg.compile(optimizer=False) m, n, k = 32, 31, 30 - x = np.ndarray([m, k], dtype=dtype) - y = np.ndarray([k, n], dtype=dtype) - result = np.ndarray([m, n], dtype=dtype) + x = np.ndarray([m, k], dtype=dtype.type, order=data_layout[0]) + y = np.ndarray([k, n], dtype=dtype.type, order=data_layout[1]) + z = np.ndarray([m, n], dtype=dtype.type, order=data_layout[2]) - x[:] = 2.5 - y[:] = 2 - result[:] = 0 + x[:] = np.random.rand(m, k) + y[:] = np.random.rand(k, n) + z[:] = 0 - csdfg(x=x, y=y, result=result, m=m, n=n, k=k) + csdfg(x=x, y=y, result=z, m=m, n=n, k=k) ref = np.dot(x, y) - diff = np.linalg.norm(ref - result) - if diff >= 1e-6: + diff = np.linalg.norm(ref - z) + if diff >= eps: print("Unexpected result returned from dot product: " "diff %f" % diff) sys.exit(1) @@ -123,17 +130,33 @@ def _test_matmul(implementation, dtype, sdfg): print("Test ran successfully for {}.".format(implementation)) -def test_matmul(): - _test_matmul( - "32-bit cuBLAS", np.float32, - make_sdfg("cuBLAS", dace.float32, dace.StorageType.GPU_Global)) - _test_matmul( - "64-bit cuBLAS", np.float64, - make_sdfg("cuBLAS", dace.float64, dace.StorageType.GPU_Global)) +def test_types(): + # Try different data types + _test_matmul('cuBLAS double', dace.float64, 'cuBLAS', + dace.StorageType.GPU_Global) + _test_matmul('cuBLAS half', dace.float16, 'cuBLAS', + dace.StorageType.GPU_Global, eps=1) + _test_matmul('cuBLAS scmplx', dace.complex64, 'cuBLAS', + dace.StorageType.GPU_Global) + _test_matmul('cuBLAS dcmplx', dace.complex128, 'cuBLAS', + dace.StorageType.GPU_Global) +def test_layouts(): + # Try all data layouts + for dl in map(lambda t: ''.join(t), itertools.product(*([['C', 'F']]*3))): + _test_matmul('cuBLAS float ' + dl, dace.float32, 'cuBLAS', + dace.StorageType.GPU_Global, data_layout=dl) ############################################################################### -if __name__ == "__main__": - test_matmul() +if __name__ == '__main__': + import os + try: + test_types() + test_layouts() + except SystemExit as ex: + print('\n', flush=True) + # Skip all teardown to avoid crashes affecting exit code + os._exit(ex.code) + os._exit(0) ############################################################################### From 4a1fe4a1d93930f6d69f5099f46751ce1625d013 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 16 Mar 2020 19:22:46 +0100 Subject: [PATCH 16/32] Fix for unsupported case in CI --- tests/library/matmul_cudatest.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/library/matmul_cudatest.py b/tests/library/matmul_cudatest.py index a883b38ead..a9e2bc0e89 100644 --- a/tests/library/matmul_cudatest.py +++ b/tests/library/matmul_cudatest.py @@ -19,7 +19,7 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, suffix = "_device" if storage != dace.StorageType.Default else "" transient = storage != dace.StorageType.Default - sdfg = dace.SDFG("cublasgemm_{}_{}".format(dtype.type.__name__, data_layout)) + sdfg = dace.SDFG("mm_{}_{}".format(dtype.type.__name__, data_layout)) state = sdfg.add_state("dataflow") # Data layout is a 3-character string with either C (for row major) @@ -121,9 +121,14 @@ def _test_matmul(implementation, dtype, impl_name, storage, ref = np.dot(x, y) + if dtype == dace.float16 and np.linalg.norm(z) == 0: + print('No computation performed, half-precision probably not ' + 'supported, skipping test.') + return + diff = np.linalg.norm(ref - z) if diff >= eps: - print("Unexpected result returned from dot product: " + print("Unexpected result returned from matrix multiplication: " "diff %f" % diff) sys.exit(1) @@ -133,13 +138,13 @@ def _test_matmul(implementation, dtype, impl_name, storage, def test_types(): # Try different data types _test_matmul('cuBLAS double', dace.float64, 'cuBLAS', - dace.StorageType.GPU_Global) + dace.StorageType.GPU_Global) _test_matmul('cuBLAS half', dace.float16, 'cuBLAS', - dace.StorageType.GPU_Global, eps=1) + dace.StorageType.GPU_Global, eps=1) _test_matmul('cuBLAS scmplx', dace.complex64, 'cuBLAS', - dace.StorageType.GPU_Global) + dace.StorageType.GPU_Global, eps=1e-4) _test_matmul('cuBLAS dcmplx', dace.complex128, 'cuBLAS', - dace.StorageType.GPU_Global) + dace.StorageType.GPU_Global) def test_layouts(): # Try all data layouts From f4f468054a622b49866af115c8e84ae7338379b4 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 16 Mar 2020 19:43:10 +0100 Subject: [PATCH 17/32] Increase test timeout --- tests/cuda_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cuda_test.sh b/tests/cuda_test.sh index 5e2897da5a..2b7763fe8c 100755 --- a/tests/cuda_test.sh +++ b/tests/cuda_test.sh @@ -12,7 +12,7 @@ ERRORS=0 FAILED_TESTS="" TESTS=0 -TEST_TIMEOUT=60 +TEST_TIMEOUT=600 RED='\033[0;31m' NC='\033[0m' From 021cf3290030b1cc503e295ac2f3b1fa6a51169d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 17 Mar 2020 14:21:00 +0100 Subject: [PATCH 18/32] Minor update --- tests/library/matmul_cudatest.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/library/matmul_cudatest.py b/tests/library/matmul_cudatest.py index a9e2bc0e89..693811f430 100644 --- a/tests/library/matmul_cudatest.py +++ b/tests/library/matmul_cudatest.py @@ -103,7 +103,7 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, def _test_matmul(implementation, dtype, impl_name, storage, - data_layout='CCC', eps=1e-6): + data_layout='CCC', eps=1e-4): sdfg = make_sdfg(impl_name, dtype, storage, data_layout) csdfg = sdfg.compile(optimizer=False) @@ -138,13 +138,13 @@ def _test_matmul(implementation, dtype, impl_name, storage, def test_types(): # Try different data types _test_matmul('cuBLAS double', dace.float64, 'cuBLAS', - dace.StorageType.GPU_Global) + dace.StorageType.GPU_Global, eps=1e-6) _test_matmul('cuBLAS half', dace.float16, 'cuBLAS', dace.StorageType.GPU_Global, eps=1) _test_matmul('cuBLAS scmplx', dace.complex64, 'cuBLAS', - dace.StorageType.GPU_Global, eps=1e-4) - _test_matmul('cuBLAS dcmplx', dace.complex128, 'cuBLAS', dace.StorageType.GPU_Global) + _test_matmul('cuBLAS dcmplx', dace.complex128, 'cuBLAS', + dace.StorageType.GPU_Global, eps=1e-6) def test_layouts(): # Try all data layouts From ba49c246ce774e43ac32ce1258c591e5154d1eb2 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 11:05:39 +0100 Subject: [PATCH 19/32] Add strided-batched matrix multiplication support to CUBLAS MatMult --- dace/libraries/blas/nodes/matmul.py | 144 ++++++++++++++++++++-------- 1 file changed, 105 insertions(+), 39 deletions(-) diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index c1b32d12c2..695580d534 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -1,6 +1,7 @@ from copy import deepcopy as dc import numpy as np -from dace.config import Config +from typing import Any, Dict, Optional +from dace.data import Array import dace.library import dace.properties import dace.graph.nodes @@ -20,7 +21,7 @@ def _get_matmul_inputs(node, state, sdfg): size = subset.size() outer_array = sdfg.data( dace.sdfg.find_input_arraynode(state, edge).data) - res = edge, outer_array, (size[0], size[1]) + res = edge, outer_array, (size[-2], size[-1]) if edge.dst_conn == "_a": res_a = res else: @@ -31,6 +32,45 @@ def _get_matmul_inputs(node, state, sdfg): return res_a, res_b +def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]: + """ + Detects whether a matrix multiplication is a batched matrix multiplication + and returns its parameters (strides, batch size), or an empty dictionary if + batched multiplication is not detected. + :param a: Data descriptor for the first tensor. + :param b: Data descriptor for the second tensor. + :param c: Data descriptor for the output tensor (optional). + :return: A dictionary with the following keys: sa,sb,sc (strides for a, b, + and c); and b (batch size). + """ + if len(a.shape) > 3 or len(b.shape) > 3 or (c and len(c.shape) > 3): + raise ValueError('Tensor dimensions too large for (batched) matrix ' + 'multiplication') + if len(a.shape) <= 2 and len(b.shape) <= 2: + return {} + + batch = None + stride_a, stride_b, stride_c = 0, 0, 0 + if len(a.shape) == 3: + batch = a.shape[0] + stride_a = a.strides[0] + if len(b.shape) == 3: + if batch and batch != b.shape[0]: + raise ValueError('Batch size mismatch for matrix multiplication') + batch = b.shape[0] + stride_b = b.strides[0] + if c and len(c.shape) == 3: + if batch and batch != c.shape[0]: + raise ValueError('Batch size mismatch for matrix multiplication') + batch = c.shape[0] + stride_c = c.strides[0] + + if batch is None: + return {} + + return {'sa': stride_a, 'sb': stride_b, 'sc': stride_c, 'b': batch} + + @dace.library.expansion class ExpandMatMulPure(ExpandTransformation): @@ -220,42 +260,57 @@ def expansion(node, state, sdfg): cdesc = sdfg.arrays[cnode.data] if not adesc or not bdesc or not cdesc: raise ValueError('Unsupported input/output arrays') - + + # Set up options for code formatting (_, _, (m, k)), (_, _, (_, n)) = _get_matmul_inputs(node, state, sdfg) opt = get_gemm_opts(adesc, bdesc, cdesc) + bopt = get_batchmm_opts(adesc, bdesc, cdesc) opt['x'] = '_a' opt['y'] = '_b' opt['M'] = m opt['N'] = n if opt['swap']: + if bopt: + bopt['sa'], bopt['sb'] = bopt['sb'], bopt['sa'] opt['lda'], opt['ldb'] = opt['ldb'], opt['lda'] opt['x'], opt['y'] = opt['y'], opt['x'] opt['ta'], opt['tb'] = opt['tb'], opt['ta'] opt['M'], opt['N'] = opt['N'], opt['M'] - - code = environments.cublas.cuBLAS.handle_setup_code(node) + ''' - cublas{func}(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, - {M}, {N}, {K}, - {alpha}, - ({dtype}*){x}, {lda}, - ({dtype}*){y}, {ldb}, - {beta}, - ({dtype}*)_c, {ldc}); - '''.format(M=opt['M'], - N=opt['N'], - K=k, - lda=opt['lda'], - ldb=opt['ldb'], - ldc=opt['ldc'], - x=opt['x'], - y=opt['y'], - ta=opt['ta'], - tb=opt['tb'], - alpha=alpha, - beta=beta, - dtype=cdtype, - func=func) + opt['K'] = k + opt['alpha'] = alpha + opt['beta'] = beta + opt['dtype'] = cdtype + opt['func'] = func + if bopt: + opt['stride_a'] = bopt['sa'] + opt['stride_b'] = bopt['sb'] + opt['stride_c'] = bopt['sc'] + opt['BATCH'] = bopt['b'] + + # Matrix multiplication + if not bopt: + call = '''cublas{func}(__dace_cublas_handle, + CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, + {M}, {N}, {K}, + {alpha}, + ({dtype}*){x}, {lda}, + ({dtype}*){y}, {ldb}, + {beta}, + ({dtype}*)_c, {ldc});''' + else: # Batched matrix multiplication + call = '''cublas{func}StridedBatched(__dace_cublas_handle, + CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, + {M}, {N}, {K}, + {alpha}, + ({dtype}*){x}, {lda}, {stride_a}, + ({dtype}*){y}, {ldb}, {stride_b}, + {beta}, + ({dtype}*)_c, {ldc}, {stride_c}, + {BATCH});''' + + code = (environments.cublas.cuBLAS.handle_setup_code(node) + + call.format_map(opt)) tasklet = dace.graph.nodes.Tasklet(node.name, node.in_connectors, node.out_connectors, @@ -268,8 +323,11 @@ def expansion(node, state, sdfg): class MatMul(dace.graph.nodes.LibraryNode): # Global properties - implementations = {"pure": ExpandMatMulPure, "MKL": ExpandMatMulMKL, - "cuBLAS": ExpandMatMulCuBLAS} + implementations = { + "pure": ExpandMatMulPure, + "MKL": ExpandMatMulMKL, + "cuBLAS": ExpandMatMulCuBLAS + } default_implementation = None # Object fields @@ -301,23 +359,31 @@ def validate(self, sdfg, state): raise ValueError( "Expected exactly one output from matrix-matrix product") out_memlet = out_edges[0].data - if len(size0) != 2: - raise ValueError( - "matrix-matrix product only supported on matrices") - if len(size1) != 2: - raise ValueError( - "matrix-matrix product only supported on matrices") - if size0[1] != size1[0]: + bopt = get_batchmm_opts(sdfg.arrays[in_edges[0].data.data], + sdfg.arrays[in_edges[1].data.data], + sdfg.arrays[out_edges[0].data.data]) + if not bopt: + if len(size0) != 2: + raise ValueError( + "matrix-matrix product only supported on matrices") + if len(size1) != 2: + raise ValueError( + "matrix-matrix product only supported on matrices") + if size0[-1] != size1[-2]: raise ValueError( "Inputs to matrix-matrix product must agree in the k-dimension" ) out_subset = dc(out_memlet.subset) out_subset.squeeze() size2 = out_subset.size() - if len(size2) != 2: + if len(size2) not in [2, 3]: raise ValueError( "matrix-matrix product only supported on matrices") - if list(size2) != [size0[0], size1[1]]: + if not bopt and list(size2) != [size0[-2], size1[-1]]: raise ValueError( - "Output to matrix-matrix product must agree in the m and n dimensions" - ) + "Output to matrix-matrix product must agree in the m and n " + "dimensions") + if bopt and list(size2) != [bopt['b'], size0[-2], size1[-1]]: + raise ValueError( + "Output to batch matrix-matrix product must agree in the b, " + "m, and n dimensions") From 0a49966650fb3832d2ccc0b16762fa17ba28c886 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 11:06:12 +0100 Subject: [PATCH 20/32] GPUTransformSDFG: Do not transform library nodes --- dace/transformation/interstate/gpu_transform_sdfg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index d960986c58..c64b517459 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -127,7 +127,8 @@ def apply(self, sdfg: sd.SDFG): and node.data not in output_nodes): output_nodes.append((node.data, node.desc(sdfg))) elif isinstance(node, nodes.CodeNode) and sdict[node] is None: - if not isinstance(node, nodes.EmptyTasklet): + if not isinstance(node, + (nodes.EmptyTasklet, nodes.LibraryNode)): global_code_nodes[i].append(node) # Input nodes may also be nodes with WCR memlets and no identity From 3948ce8e881915b9f2116e61b58095eba0627512 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 11:07:11 +0100 Subject: [PATCH 21/32] Python frontend: Support batched matrix multiplication with @ operator --- dace/frontend/python/newast.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 10b555f761..58c12d3a74 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -668,24 +668,36 @@ def _op(visitor: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, op1: str, @oprepo.replaces_operator('Array', 'MatMult') def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): - arr1 = sdfg.arrays[op1] arr2 = sdfg.arrays[op2] - if (len(arr1.shape) != 2 or len(arr2.shape) != 2 - or arr1.shape[1] != arr2.shape[0]): - raise SyntaxError('Matrix sizes must match') + # TODO: Apply numpy broadcast rules + if len(arr1.shape) > 3 or len(arr2.shape) > 3: + raise SyntaxError('Matrix multiplication of tensors of dimensions > 3 ' + 'not supported') + if arr1.shape[-1] != arr2.shape[-2]: + raise SyntaxError('Matrix dimension mismatch %s != %s' % + (arr1.shape[-1], arr2.shape[-2])) + + import dace.libraries.blas as blas # Avoid import loop + from dace.libraries.blas.nodes.matmul import get_batchmm_opts + + # Determine batched multiplication + bopt = get_batchmm_opts(arr1, arr2, None) + if bopt: + output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1]) + else: + output_shape = (arr1.shape[-2], arr2.shape[-1]) type1 = arr1.dtype.type type2 = arr2.dtype.type restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type] - op3, arr3 = sdfg.add_temp_transient((arr1.shape[0], arr2.shape[1]), - restype, arr1.storage) + op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage) acc1 = state.add_read(op1) acc2 = state.add_read(op2) acc3 = state.add_write(op3) - import dace.libraries.blas as blas # Avoid import loop + tasklet = blas.MatMul('_MatMult_', restype) state.add_node(tasklet) state.add_edge(acc1, None, tasklet, '_a', @@ -1453,7 +1465,8 @@ def __init__( for stmt in _DISALLOWED_STMTS: setattr(self, 'visit_' + stmt, lambda n: _disallow_stmt(self, n)) - def parse_tasklet(self, tasklet_ast: TaskletType, + def parse_tasklet(self, + tasklet_ast: TaskletType, name: Optional[str] = None): """ Parses the AST of a tasklet and returns the tasklet node, as well as input and output memlets. :param tasklet_ast: The Tasklet's Python AST to parse. From 441c10d4feb8a0d564baf6bc76d8b56ceedff31a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 11:07:41 +0100 Subject: [PATCH 22/32] Batched matrix multiplication CUDA test --- tests/library/matmul_cudatest.py | 87 +++++++++++++++++++++++++------- 1 file changed, 68 insertions(+), 19 deletions(-) diff --git a/tests/library/matmul_cudatest.py b/tests/library/matmul_cudatest.py index 693811f430..524c0cad9c 100644 --- a/tests/library/matmul_cudatest.py +++ b/tests/library/matmul_cudatest.py @@ -10,7 +10,9 @@ ############################################################################### -def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, +def make_sdfg(implementation, + dtype, + storage=dace.StorageType.Default, data_layout='CCC'): m = dace.symbol("m") n = dace.symbol("n") @@ -28,7 +30,6 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, ystrides = (n, 1) if data_layout[1] == 'C' else (1, k) zstrides = (n, 1) if data_layout[2] == 'C' else (1, m) - sdfg.add_array("x" + suffix, [m, k], dtype, storage=storage, @@ -79,12 +80,10 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, y_device = init_state.add_write("y" + suffix) init_state.add_memlet_path(x_host, x_device, - memlet=Memlet.simple(x_host, - "0:m, 0:k")) + memlet=Memlet.simple(x_host, "0:m, 0:k")) init_state.add_memlet_path(y_host, y_device, - memlet=Memlet.simple(y_host, - "0:k, 0:n")) + memlet=Memlet.simple(y_host, "0:k, 0:n")) finalize_state = sdfg.add_state("copy_to_host") sdfg.add_edge(state, finalize_state, dace.InterstateEdge()) @@ -93,8 +92,8 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, result_host = finalize_state.add_read("result") finalize_state.add_memlet_path(result_device, result_host, - memlet=Memlet.simple(result_device, - "0:m, 0:n")) + memlet=Memlet.simple( + result_device, "0:m, 0:n")) return sdfg @@ -102,8 +101,12 @@ def make_sdfg(implementation, dtype, storage=dace.StorageType.Default, ############################################################################### -def _test_matmul(implementation, dtype, impl_name, storage, - data_layout='CCC', eps=1e-4): +def _test_matmul(implementation, + dtype, + impl_name, + storage, + data_layout='CCC', + eps=1e-4): sdfg = make_sdfg(impl_name, dtype, storage, data_layout) csdfg = sdfg.compile(optimizer=False) @@ -137,26 +140,72 @@ def _test_matmul(implementation, dtype, impl_name, storage, def test_types(): # Try different data types - _test_matmul('cuBLAS double', dace.float64, 'cuBLAS', - dace.StorageType.GPU_Global, eps=1e-6) - _test_matmul('cuBLAS half', dace.float16, 'cuBLAS', - dace.StorageType.GPU_Global, eps=1) + _test_matmul('cuBLAS double', + dace.float64, + 'cuBLAS', + dace.StorageType.GPU_Global, + eps=1e-6) + _test_matmul('cuBLAS half', + dace.float16, + 'cuBLAS', + dace.StorageType.GPU_Global, + eps=1) _test_matmul('cuBLAS scmplx', dace.complex64, 'cuBLAS', dace.StorageType.GPU_Global) - _test_matmul('cuBLAS dcmplx', dace.complex128, 'cuBLAS', - dace.StorageType.GPU_Global, eps=1e-6) + _test_matmul('cuBLAS dcmplx', + dace.complex128, + 'cuBLAS', + dace.StorageType.GPU_Global, + eps=1e-6) + def test_layouts(): # Try all data layouts - for dl in map(lambda t: ''.join(t), itertools.product(*([['C', 'F']]*3))): - _test_matmul('cuBLAS float ' + dl, dace.float32, 'cuBLAS', - dace.StorageType.GPU_Global, data_layout=dl) + for dl in map(lambda t: ''.join(t), + itertools.product(*([['C', 'F']] * 3))): + _test_matmul('cuBLAS float ' + dl, + dace.float32, + 'cuBLAS', + dace.StorageType.GPU_Global, + data_layout=dl) + + +def test_batchmm(): + b, m, n, k = tuple(dace.symbol(k) for k in 'bmnk') + + @dace.program + def bmmtest(A: dace.float64[b, m, k], B: dace.float64[b, k, n], + C: dace.float64[b, m, n]): + C[:] = A @ B + + sdfg = bmmtest.to_sdfg() + sdfg.apply_gpu_transformations() + for state in sdfg.nodes(): + for node in state.nodes(): + if isinstance(node, blas.nodes.matmul.MatMul): + node.implementation = 'cuBLAS' + csdfg = sdfg.compile(optimizer=False) + + b, m, n, k = 3, 32, 31, 30 + + x = np.random.rand(b, m, k) + y = np.random.rand(b, k, n) + z = np.zeros([b, m, n], np.float64) + csdfg(A=x, B=y, C=z, b=b, m=m, n=n, k=k) + + ref = x @ y + + diff = np.linalg.norm(ref - z) + print('Difference:', diff) + assert diff < 1e-6 + ############################################################################### if __name__ == '__main__': import os try: + test_batchmm() test_types() test_layouts() except SystemExit as ex: From 48a2532918b80ff615703dd0211490b1b67cb219 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 11:23:58 +0100 Subject: [PATCH 23/32] Matrix multiplication operator test --- tests/numpy/matrix_multiplication_test.py | 44 +++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 tests/numpy/matrix_multiplication_test.py diff --git a/tests/numpy/matrix_multiplication_test.py b/tests/numpy/matrix_multiplication_test.py new file mode 100644 index 0000000000..15d3db3519 --- /dev/null +++ b/tests/numpy/matrix_multiplication_test.py @@ -0,0 +1,44 @@ +import unittest +import dace +import numpy as np + +B, M, N, K = tuple(dace.symbol(k) for k in 'BMNK') + + +class MatrixMultiplication(unittest.TestCase): + def test_mmm(self): + @dace.program + def mmmtest(a: dace.float64[M, K], b: dace.float64[K, N]): + return a @ b + + a = np.random.rand(32, 33) + b = np.random.rand(33, 34) + c = mmmtest(a, b) + self.assertEqual(list(c.shape), [32, 34]) + self.assertTrue(np.allclose(c, a @ b)) + + def test_mmm_batch(self): + @dace.program + def mmmtest(a: dace.float64[B, M, K], b: dace.float64[B, K, N]): + return a @ b + + a = np.random.rand(3, 34, 32) + b = np.random.rand(3, 32, 31) + c = mmmtest(a, b) + self.assertEqual(list(c.shape), [3, 34, 31]) + self.assertTrue(np.allclose(c, a @ b)) + + def test_mmm_batch_stationary_a(self): + @dace.program + def mmmtest(a: dace.float64[M, K], b: dace.float64[B, K, N]): + return a @ b + + a = np.random.rand(34, 32) + b = np.random.rand(3, 32, 31) + c = mmmtest(a, b) + self.assertEqual(list(c.shape), [3, 34, 31]) + self.assertTrue(np.allclose(c, a @ b)) + + +if __name__ == '__main__': + unittest.main() From 49c57cedb6f4b506406a60060a423093a88e2b18 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 11:53:48 +0100 Subject: [PATCH 24/32] Pure implementation of batched matrix mutliplication --- dace/libraries/blas/nodes/matmul.py | 114 ++++++++++++++-------------- 1 file changed, 56 insertions(+), 58 deletions(-) diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index 695580d534..00994064cf 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -21,7 +21,7 @@ def _get_matmul_inputs(node, state, sdfg): size = subset.size() outer_array = sdfg.data( dace.sdfg.find_input_arraynode(state, edge).data) - res = edge, outer_array, (size[-2], size[-1]) + res = edge, outer_array, size if edge.dst_conn == "_a": res_a = res else: @@ -85,16 +85,19 @@ def make_sdfg(node, parent_state, parent_sdfg): ((edge_a, outer_array_a, shape_a), (edge_b, outer_array_b, shape_b)) = _get_matmul_inputs(node, parent_state, parent_sdfg) + cdesc = parent_sdfg.arrays[parent_state.out_edges(node)[0].data.data] + bopt = get_batchmm_opts(outer_array_a, outer_array_b, cdesc) - if (len(shape_a) != 2 or len(shape_b) != 2 - or shape_a[1] != shape_b[0]): + if shape_a[-1] != shape_b[-2]: raise SyntaxError('Matrix sizes must match') - shape_c = (shape_a[0], shape_b[1]) + if bopt: + shape_c = (bopt['b'], shape_a[-2], shape_b[-1]) + else: + shape_c = (shape_a[-2], shape_b[-1]) dtype_a = outer_array_a.dtype.type dtype_b = outer_array_b.dtype.type - dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, - dtype_b).type] + dtype_c = cdesc.dtype.type if outer_array_a.storage != outer_array_b.storage: raise ValueError("Input matrices must have same storage") @@ -104,23 +107,50 @@ def make_sdfg(node, parent_state, parent_sdfg): _, array_b = sdfg.add_array("_b", shape_b, dtype_b, storage=storage) _, array_c = sdfg.add_array("_c", shape_c, dtype_c, storage=storage) - state.add_mapped_tasklet( - '_MatMult_', { + if not bopt: + state.add_mapped_tasklet('_MatMult_', { '__i%d' % i: '0:%s' % s for i, s in enumerate( - [array_a.shape[0], array_b.shape[1], array_a.shape[1]]) + [array_a.shape[-2], array_b.shape[-1], array_a.shape[-1]]) }, { - '__a': dace.Memlet.simple("_a", '__i0, __i2'), - '__b': dace.Memlet.simple("_b", '__i2, __i1') + '__a': + dace.Memlet.simple("_a", '__i0, __i2'), + '__b': + dace.Memlet.simple("_b", '__i2, __i1') }, - '__c = __a * __b', { - '__c': - dace.Memlet.simple("_c", - '__i0, __i1', - wcr_str='lambda x, y: x + y', - wcr_identity=0) - }, - external_edges=True) + '__c = __a * __b', { + '__c': + dace.Memlet.simple( + "_c", + '__i0, __i1', + wcr_str='lambda x, y: x + y', + wcr_identity=0) + }, + external_edges=True) + else: # Batched matrix multiplication + state.add_mapped_tasklet( + '_BatchedMatMult_', { + '__i%d' % i: '0:%s' % s + for i, s in enumerate([ + bopt['b'], array_a.shape[-2], array_b.shape[-1], + array_a.shape[-1] + ]) + }, { + '__a': + dace.Memlet.simple("_a", ('__i1, __i3' if len( + array_a.shape) == 2 else '__i0, __i1, __i3')), + '__b': + dace.Memlet.simple("_b", ('__i3, __i2' if len( + array_b.shape) == 2 else '__i0, __i3, __i2')) + }, + '__c = __a * __b', { + '__c': + dace.Memlet.simple("_c", + '__i0, __i1, __i2', + wcr_str='lambda x, y: x + y', + wcr_identity=0) + }, + external_edges=True) sdfg.parent = parent_sdfg sdfg.parent_sdfg = parent_sdfg @@ -161,44 +191,12 @@ def expansion(node, state, sdfg): else: raise ValueError("Unsupported type for BLAS dot product: " + str(dtype)) - (_, _, (m, k)), (_, _, (_, n)) = _get_matmul_inputs(node, state, sdfg) - code = ("cblas_{f}(CblasRowMajor, CblasNoTrans, CblasNoTrans, " - "{m}, {n}, {k}, {a}, _a, {k}, _b, {n}, {b}, _c, {n});").format( - f=func, m=m, n=n, k=k, a=alpha, b=beta) - tasklet = dace.graph.nodes.Tasklet(node.name, - node.in_connectors, - node.out_connectors, - code, - language=dace.dtypes.Language.CPP) - return tasklet - - -@dace.library.expansion -class ExpandMatMulMKL(ExpandTransformation): - - environments = [environments.intel_mkl.IntelMKL] - - @staticmethod - def expansion(node, state, sdfg): - node.validate(sdfg, state) - dtype = node.dtype - func = to_blastype(dtype.type).lower() + 'gemm' - if dtype == dace.float32: - alpha = "1.0f" - beta = "0.0f" - elif dtype == dace.float64: - alpha = "1.0" - beta = "0.0" - elif dtype == dace.complex64: - alpha = "dace::blas::BlasConstants::Get().Complex64Pone()" - beta = "dace::blas::BlasConstants::Get().Complex64Zero()" - elif dtype == dace.complex128: - alpha = "dace::blas::BlasConstants::Get().Complex128Pone()" - beta = "dace::blas::BlasConstants::Get().Complex128Zero()" - else: - raise ValueError("Unsupported type for BLAS dot product: " + - str(dtype)) - (_, _, (m, k)), (_, _, (_, n)) = _get_matmul_inputs(node, state, sdfg) + (_, _, ashape), (_, _, bshape) = _get_matmul_inputs(node, state, sdfg) + m, k = ashape[-2:] + n = bshape[-1] + # TODO: Use strides instead of shape + # TODO: Use gemm opts + # TODO: Use batch gemm opts code = ("cblas_{f}(CblasRowMajor, CblasNoTrans, CblasNoTrans, " "{m}, {n}, {k}, {a}, _a, {k}, _b, {n}, {b}, _c, {n});").format( f=func, m=m, n=n, k=k, a=alpha, b=beta) @@ -217,7 +215,6 @@ class ExpandMatMulCuBLAS(ExpandTransformation): @staticmethod def expansion(node, state, sdfg): - gpuid = node.location or '0' node.validate(sdfg, state) dtype = node.dtype func = '%sgemm' % to_blastype(dtype.type) @@ -359,6 +356,7 @@ def validate(self, sdfg, state): raise ValueError( "Expected exactly one output from matrix-matrix product") out_memlet = out_edges[0].data + # Function is symmetric, edge order does not matter bopt = get_batchmm_opts(sdfg.arrays[in_edges[0].data.data], sdfg.arrays[in_edges[1].data.data], sdfg.arrays[out_edges[0].data.data]) From 6565cfe3baa2054d16dbf6a34cf3a6163749a045 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 14:09:06 +0100 Subject: [PATCH 25/32] MKL version of matrix multiplication node --- dace/libraries/blas/nodes/matmul.py | 97 ++++++++++++++++++----------- 1 file changed, 62 insertions(+), 35 deletions(-) diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index 00994064cf..7b069244ff 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -71,6 +71,41 @@ def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]: return {'sa': stride_a, 'sb': stride_b, 'sc': stride_c, 'b': batch} +def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, + cdtype, func) -> Dict[str, Any]: + """ Get option map for GEMM code generation (with column-major order). """ + (_, _, ashape), (_, _, bshape) = _get_matmul_inputs(node, state, sdfg) + opt = get_gemm_opts(adesc, bdesc, cdesc) + bopt = get_batchmm_opts(adesc, bdesc, cdesc) + opt['x'] = '_a' + opt['y'] = '_b' + opt['M'] = ashape[-2] + opt['N'] = bshape[-1] + opt['K'] = ashape[-1] + + if opt['swap']: + if bopt: + bopt['sa'], bopt['sb'] = bopt['sb'], bopt['sa'] + opt['lda'], opt['ldb'] = opt['ldb'], opt['lda'] + opt['x'], opt['y'] = opt['y'], opt['x'] + opt['ta'], opt['tb'] = opt['tb'], opt['ta'] + opt['M'], opt['N'] = opt['N'], opt['M'] + + opt['alpha'] = alpha + opt['beta'] = beta + opt['dtype'] = cdtype + opt['func'] = func + if bopt: + opt['stride_a'] = bopt['sa'] + opt['stride_b'] = bopt['sb'] + opt['stride_c'] = bopt['sc'] + opt['BATCH'] = bopt['b'] + else: + opt['BATCH'] = None + + return opt + + @dace.library.expansion class ExpandMatMulPure(ExpandTransformation): @@ -191,15 +226,30 @@ def expansion(node, state, sdfg): else: raise ValueError("Unsupported type for BLAS dot product: " + str(dtype)) - (_, _, ashape), (_, _, bshape) = _get_matmul_inputs(node, state, sdfg) - m, k = ashape[-2:] - n = bshape[-1] - # TODO: Use strides instead of shape - # TODO: Use gemm opts - # TODO: Use batch gemm opts - code = ("cblas_{f}(CblasRowMajor, CblasNoTrans, CblasNoTrans, " - "{m}, {n}, {k}, {a}, _a, {k}, _b, {n}, {b}, _c, {n});").format( - f=func, m=m, n=n, k=k, a=alpha, b=beta) + (_, adesc, ashape), (_, bdesc, + bshape) = _get_matmul_inputs(node, state, sdfg) + cdesc = sdfg.arrays[state.out_edges(node)[0].data.data] + opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, + alpha, beta, cdesc.dtype.ctype, func) + + # Adaptations for MKL/BLAS API + opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans' + opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans' + + if not opt['BATCH']: + code = ("cblas_{func}(CblasColMajor, {ta}, {tb}, " + "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, " + "_c, {ldc});").format_map(opt) + else: + code = ''' + for (int __ib = 0; __ib < {BATCH}; ++__ib) {{ + cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha}, + (({dtype}*){x}) + __ib*{stride_a}, {lda}, + (({dtype}*){y}) + __ib*{stride_b}, {ldb}, + {beta}, + (({dtype}*)_c) + __ib*{stride_c}, {ldc}); + }}'''.format_map(opt) + tasklet = dace.graph.nodes.Tasklet(node.name, node.in_connectors, node.out_connectors, @@ -259,34 +309,11 @@ def expansion(node, state, sdfg): raise ValueError('Unsupported input/output arrays') # Set up options for code formatting - (_, _, (m, k)), (_, _, (_, n)) = _get_matmul_inputs(node, state, sdfg) - opt = get_gemm_opts(adesc, bdesc, cdesc) - bopt = get_batchmm_opts(adesc, bdesc, cdesc) - opt['x'] = '_a' - opt['y'] = '_b' - opt['M'] = m - opt['N'] = n - if opt['swap']: - if bopt: - bopt['sa'], bopt['sb'] = bopt['sb'], bopt['sa'] - opt['lda'], opt['ldb'] = opt['ldb'], opt['lda'] - opt['x'], opt['y'] = opt['y'], opt['x'] - opt['ta'], opt['tb'] = opt['tb'], opt['ta'] - opt['M'], opt['N'] = opt['N'], opt['M'] - - opt['K'] = k - opt['alpha'] = alpha - opt['beta'] = beta - opt['dtype'] = cdtype - opt['func'] = func - if bopt: - opt['stride_a'] = bopt['sa'] - opt['stride_b'] = bopt['sb'] - opt['stride_c'] = bopt['sc'] - opt['BATCH'] = bopt['b'] + opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, + alpha, beta, cdtype, func) # Matrix multiplication - if not bopt: + if not opt['BATCH']: call = '''cublas{func}(__dace_cublas_handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, {M}, {N}, {K}, From d742adf7a1184399e5337aa4b0e80dc6a396537a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 15:39:41 +0100 Subject: [PATCH 26/32] MatMult: Fix code generation symbol output --- dace/libraries/blas/nodes/matmul.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index 7b069244ff..d44ce9aa44 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -1,5 +1,4 @@ from copy import deepcopy as dc -import numpy as np from typing import Any, Dict, Optional from dace.data import Array import dace.library @@ -74,14 +73,20 @@ def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]: def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdtype, func) -> Dict[str, Any]: """ Get option map for GEMM code generation (with column-major order). """ + # Avoid import loops + from dace.codegen.targets.common import sym2cpp + (_, _, ashape), (_, _, bshape) = _get_matmul_inputs(node, state, sdfg) opt = get_gemm_opts(adesc, bdesc, cdesc) bopt = get_batchmm_opts(adesc, bdesc, cdesc) opt['x'] = '_a' opt['y'] = '_b' - opt['M'] = ashape[-2] - opt['N'] = bshape[-1] - opt['K'] = ashape[-1] + opt['M'] = sym2cpp(ashape[-2]) + opt['N'] = sym2cpp(bshape[-1]) + opt['K'] = sym2cpp(ashape[-1]) + opt['lda'] = sym2cpp(opt['lda']) + opt['ldb'] = sym2cpp(opt['ldb']) + opt['ldc'] = sym2cpp(opt['ldc']) if opt['swap']: if bopt: @@ -96,10 +101,10 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, opt['dtype'] = cdtype opt['func'] = func if bopt: - opt['stride_a'] = bopt['sa'] - opt['stride_b'] = bopt['sb'] - opt['stride_c'] = bopt['sc'] - opt['BATCH'] = bopt['b'] + opt['stride_a'] = sym2cpp(bopt['sa']) + opt['stride_b'] = sym2cpp(bopt['sb']) + opt['stride_c'] = sym2cpp(bopt['sc']) + opt['BATCH'] = sym2cpp(bopt['b']) else: opt['BATCH'] = None @@ -295,16 +300,16 @@ def expansion(node, state, sdfg): if e.dst_conn == '_a': anode = state.memlet_path(e)[0].src if isinstance(anode, dace.graph.nodes.AccessNode): - adesc = sdfg.arrays[anode.data] + adesc: Array = sdfg.arrays[anode.data] elif e.dst_conn == '_b': bnode = state.memlet_path(e)[0].src if isinstance(bnode, dace.graph.nodes.AccessNode): - bdesc = sdfg.arrays[bnode.data] + bdesc: Array = sdfg.arrays[bnode.data] for e in state.out_edges(node): if e.src_conn == '_c': cnode = state.memlet_path(e)[-1].dst if isinstance(cnode, dace.graph.nodes.AccessNode): - cdesc = sdfg.arrays[cnode.data] + cdesc: Array = sdfg.arrays[cnode.data] if not adesc or not bdesc or not cdesc: raise ValueError('Unsupported input/output arrays') From 4cc2836719b91bd2fd3c79aac5aef73194d1cc00 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 15:40:14 +0100 Subject: [PATCH 27/32] CUBLAS MatMult: Copy data to GPU if not already there --- dace/libraries/blas/nodes/matmul.py | 36 +++++++++++++++++++++++++ dace/transformation/pattern_matching.py | 2 +- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index d44ce9aa44..6780adc45c 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -345,6 +345,42 @@ def expansion(node, state, sdfg): node.out_connectors, code, language=dace.dtypes.Language.CPP) + + # If buffers are not on the GPU, copy them + # TODO: This creates variable shadowing + if any(desc.storage not in + [dace.StorageType.GPU_Global, dace.StorageType.CPU_Pinned] + for desc in [adesc, bdesc, cdesc]): + nsdfg = dace.SDFG('nested_matmul') + for name, desc in [('_a', adesc), ('_b', bdesc), ('_c', cdesc)]: + dcopy = dc(desc) + dcopy.transient = False + nsdfg.add_datadesc(name, dcopy) + dcopy_gpu = dc(desc) + dcopy_gpu.transient = True + dcopy_gpu.storage = dace.StorageType.GPU_Global + nsdfg.add_datadesc(name + '_gpu', dcopy_gpu) + nstate = nsdfg.add_state() + a = nstate.add_read('_a') + ga = nstate.add_access('_a_gpu') + b = nstate.add_read('_b') + gb = nstate.add_access('_b_gpu') + c = nstate.add_write('_c') + gc = nstate.add_access('_c_gpu') + nstate.add_node(tasklet) + nstate.add_nedge(a, ga, dace.Memlet.from_array('_a', adesc)) + nstate.add_nedge(b, gb, dace.Memlet.from_array('_b', bdesc)) + nstate.add_edge(ga, None, tasklet, '_a', + dace.Memlet.from_array('_a_gpu', adesc)) + nstate.add_edge(gb, None, tasklet, '_b', + dace.Memlet.from_array('_b_gpu', bdesc)) + nstate.add_edge(tasklet, '_c', gc, None, + dace.Memlet.from_array('_c_gpu', cdesc)) + nstate.add_nedge(gc, c, dace.Memlet.from_array('_c', cdesc)) + + return nsdfg + # End of copy to GPU + return tasklet diff --git a/dace/transformation/pattern_matching.py b/dace/transformation/pattern_matching.py index e6ef25dbcb..2a1828c920 100644 --- a/dace/transformation/pattern_matching.py +++ b/dace/transformation/pattern_matching.py @@ -9,7 +9,7 @@ from dace.sdfg import SDFG, SDFGState from dace.properties import make_properties, Property, SubgraphProperty from dace.registry import make_registry -from dace.graph import labeling, graph as gr, nodes as nd +from dace.graph import labeling, graph as gr, nodes as nd, nxutil import networkx as nx from networkx.algorithms import isomorphism as iso from typing import Dict, List, Tuple, Type, Union From 760019672163aa6c328dc72a9aaaa71b409c39ac Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 15:41:36 +0100 Subject: [PATCH 28/32] Einsum: Use library nodes when possible --- dace/frontend/common/einsum.py | 199 ++++++++++----------------------- 1 file changed, 56 insertions(+), 143 deletions(-) diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index d2c88982cd..cd0dbaab8e 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -109,147 +109,48 @@ def __repr__(self): return str(self) -# TODO: Remove once library nodes are used -cublas_initialized = False - - def create_batch_gemm_sdfg(dtype, strides): - # TODO: Use MatMult library node ######################### sdfg = SDFG('einsum') state = sdfg.add_state() BATCH, M, K, N, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB = ( - #symbolic.symbol(s) for s in [ - strides[s] for s in [ + symbolic.symbol(s) for s in [ 'BATCH', 'M', 'K', 'N', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM', 'sCN', 'sCB' ]) batched = strides['BATCH'] != 1 - _, xarr = sdfg.add_array('X', - dtype=dtype, - shape=[BATCH, M, K] if batched else [M, K], - strides=[sAB, sAM, sAK] if batched else [ - sAM, sAK], - storage=dtypes.StorageType.GPU_Global) - _, yarr = sdfg.add_array('Y', - dtype=dtype, - shape=[BATCH, K, N] if batched else [K, N], - strides=[sBB, sBK, sBN] if batched else [ - sBK, sBN], - storage=dtypes.StorageType.GPU_Global) - _, zarr = sdfg.add_array('Z', - dtype=dtype, - shape=[BATCH, M, N] if batched else [M, N], - strides=[sCB, sCM, sCN] if batched else [ - sCM, sCN], - storage=dtypes.StorageType.GPU_Global) + _, xarr = sdfg.add_array( + 'X', + dtype=dtype, + shape=[BATCH, M, K] if batched else [M, K], + strides=[sAB, sAM, sAK] if batched else [sAM, sAK], + storage=dtypes.StorageType.GPU_Global) + _, yarr = sdfg.add_array( + 'Y', + dtype=dtype, + shape=[BATCH, K, N] if batched else [K, N], + strides=[sBB, sBK, sBN] if batched else [sBK, sBN], + storage=dtypes.StorageType.GPU_Global) + _, zarr = sdfg.add_array( + 'Z', + dtype=dtype, + shape=[BATCH, M, N] if batched else [M, N], + strides=[sCB, sCM, sCN] if batched else [sCM, sCN], + storage=dtypes.StorageType.GPU_Global) gX = state.add_read('X') gY = state.add_read('Y') - gZ = state.add_access('Z') - - opt = get_gemm_opts(xarr, yarr, zarr) - - opt['sta'] = sAB - opt['stb'] = sBB - opt['stc'] = sCB - opt['x'] = 'x' - opt['y'] = 'y' - opt['M'] = M - opt['N'] = N - if opt['swap']: - opt['lda'], opt['ldb'] = opt['ldb'], opt['lda'] - opt['sta'], opt['stb'] = opt['stb'], opt['sta'] - opt['x'], opt['y'] = opt['y'], opt['x'] - opt['ta'], opt['tb'] = opt['tb'], opt['ta'] - opt['M'], opt['N'] = opt['N'], opt['M'] - - global cublas_initialized - if not cublas_initialized: - code_global = ''' - #include - cublasHandle_t handle; - ''' - code_init = 'cublasCreate(&handle);' - code_exit = 'cublasDestroy(handle);' - cublas_initialized = True - else: - code_global = '' - code_init = '' - code_exit = '' - - cublas_gemm = 'cublas%sgemm' % to_blastype(dtype.type) - - if not batched: - code = ''' - cublasSetStream(handle, __dace_current_stream); - {c_dtype} alpha_unused = 1.0, beta_unused = 0.0; - {cublas_gemm}(handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, - {M}, {N}, {K}, - &alpha_unused, - {x}, {lda}, - {y}, {ldb}, - &beta_unused, - z, {ldc}); - '''.format(BATCH=BATCH, - M=opt['M'], - N=opt['N'], - K=K, - lda=opt['lda'], - ldb=opt['ldb'], - ldc=opt['ldc'], - x=opt['x'], - y=opt['y'], - ta=opt['ta'], - tb=opt['tb'], - c_dtype=dtype.ctype, - cublas_gemm=cublas_gemm) - else: - code = ''' - cublasSetStream(handle, __dace_current_stream); - {c_dtype} alpha_unused = 1.0, beta_unused = 0.0; - {cublas_gemm}StridedBatched(handle, CUBLAS_OP_{ta}, CUBLAS_OP_{tb}, - {M}, {N}, {K}, - &alpha_unused, - {x}, {lda}, {stride_a}, - {y}, {ldb}, {stride_b}, - &beta_unused, - z, {ldc}, {stride_c}, - {BATCH}); - '''.format(BATCH=BATCH, - M=opt['M'], - N=opt['N'], - K=K, - lda=opt['lda'], - ldb=opt['ldb'], - ldc=opt['ldc'], - stride_a=opt['sta'], - stride_b=opt['stb'], - stride_c=opt['stc'], - x=opt['x'], - y=opt['y'], - ta=opt['ta'], - tb=opt['tb'], - c_dtype=dtype.ctype, - cublas_gemm=cublas_gemm) - - cublas_tasklet = state.add_tasklet(name="cublas_tasklet", - inputs={'x', 'y'}, - outputs={'z'}, - code=code, - code_global=code_global, - code_init=code_init, - code_exit=code_exit, - language=dtypes.Language.CPP) - - state.add_edge(gX, None, cublas_tasklet, 'x', - Memlet.from_array(gX, gX.desc(sdfg))) - state.add_edge(gY, None, cublas_tasklet, 'y', - Memlet.from_array(gY, gY.desc(sdfg))) - state.add_edge(cublas_tasklet, 'z', gZ, None, - Memlet.from_array(gZ, gZ.desc(sdfg))) + gZ = state.add_write('Z') + + import dace.libraries.blas as blas # Avoid import loop + + libnode = blas.MatMul('einsum_gemm', zarr.dtype) + state.add_node(libnode) + state.add_edge(gX, None, libnode, '_a', Memlet.from_array(gX.data, xarr)) + state.add_edge(gY, None, libnode, '_b', Memlet.from_array(gY.data, yarr)) + state.add_edge(libnode, '_c', gZ, None, Memlet.from_array(gZ.data, zarr)) return sdfg @@ -266,8 +167,12 @@ def create_einsum_sdfg(sdfg: SDFG, dtype: Optional[dtypes.typeclass] = None, optimize: bool = False, output: Optional[str] = None): - return _create_einsum_internal(sdfg, state, einsum_string, *arrays, - dtype=dtype, optimize=optimize, + return _create_einsum_internal(sdfg, + state, + einsum_string, + *arrays, + dtype=dtype, + optimize=optimize, output=output)[0] @@ -320,9 +225,15 @@ def _create_einsum_internal(sdfg: SDFG, # Follow path and create a chain of operation SDFG states for pair, nonfree, expr, after, blas in path_info.contraction_list: - result, result_node = _create_einsum_internal( - sdfg, state, expr, arrays[pair[0]], arrays[pair[1]], - dtype=dtype, optimize=False, output=None, nodes=input_nodes) + result, result_node = _create_einsum_internal(sdfg, + state, + expr, + arrays[pair[0]], + arrays[pair[1]], + dtype=dtype, + optimize=False, + output=None, + nodes=input_nodes) arrays = ([a for i, a in enumerate(arrays) if i not in pair] + [result]) input_nodes[result] = result_node @@ -355,12 +266,13 @@ def _create_einsum_internal(sdfg: SDFG, if len(einsum.output) > 0: init_state.add_mapped_tasklet( 'einsum_reset', - {k: '0:%s' % chardict[k] for k in einsum.output}, - {}, 'out_%s = 0' % output, + {k: '0:%s' % chardict[k] + for k in einsum.output}, {}, + 'out_%s = 0' % output, {'out_%s' % output: Memlet.simple(output, output_index)}, external_edges=True) else: # Scalar output - t = init_state.add_tasklet('einsum_reset', {}, + t = init_state.add_tasklet('einsum_reset', set(), {'out_%s' % output}, 'out_%s = 0' % output) onode = init_state.add_write(output) @@ -370,20 +282,21 @@ def _create_einsum_internal(sdfg: SDFG, # Pure einsum map state.add_mapped_tasklet( 'einsum', {k: '0:%s' % v - for k, v in chardict.items()}, + for k, v in chardict.items()}, { + 'inp_%s' % arr: Memlet.simple(arr, ','.join(inp)) + for inp, arr in zip(einsum.inputs, arrays) + }, + 'out_%s = %s' % (output, ' * '.join('inp_%s' % arr + for arr in arrays)), { - 'inp_%s' % arr: Memlet.simple(arr, ','.join(inp)) - for inp, arr in zip(einsum.inputs, arrays) + 'out_%s' % output: + Memlet.simple(output, output_index, wcr_str='lambda a,b: a+b') }, - 'out_%s = %s' % (output, ' * '.join('inp_%s' % arr for arr in arrays)), - {'out_%s' % output: Memlet.simple(output, output_index, - wcr_str='lambda a,b: a+b')}, input_nodes=input_nodes, output_nodes={output: c}, external_edges=True) else: - # TODO: Only CUDA is supported - # Represent einsum as a GEMM or batched GEMM + # Represent einsum as a GEMM or batched GEMM (using library nodes) a_shape = sdfg.arrays[arrays[0]].shape b_shape = sdfg.arrays[arrays[1]].shape c_shape = output_shape From 5bc63812bbe7fa983827694135243d250705b816 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 18:23:49 +0100 Subject: [PATCH 29/32] Adapt einsum replacement to work with dot products --- dace/frontend/common/einsum.py | 22 ++++++++++++++++++---- dace/libraries/blas/blas_helpers.py | 7 ++++--- dace/libraries/blas/nodes/matmul.py | 8 ++++---- tests/numpy/einsum_test.py | 1 + 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index cd0dbaab8e..a4e57f0d54 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -113,10 +113,12 @@ def create_batch_gemm_sdfg(dtype, strides): ######################### sdfg = SDFG('einsum') state = sdfg.add_state() - BATCH, M, K, N, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB = ( - symbolic.symbol(s) for s in [ - 'BATCH', 'M', 'K', 'N', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', - 'sCM', 'sCN', 'sCB' + M, K, N = (symbolic.symbol(s) for s in ['M', 'K', 'N']) + BATCH, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB = ( + symbolic.symbol(s) if symbolic.issymbolic(strides[s]) else strides[s] + for s in [ + 'BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM', 'sCN', + 'sCB' ]) batched = strides['BATCH'] != 1 @@ -326,6 +328,18 @@ def _create_einsum_internal(sdfg: SDFG, sCB=prod(c_shape[einsum.c_batch[-1] + 1:]) if einsum.c_batch else 1) + # Complement strides to make matrices as necessary + if len(a_shape) == 1 and len(einsum.a_sum) == 1: + strides['sAK'] = 1 + strides['sAB'] = strides['sAM'] = strides['K'] + if len(b_shape) == 1 and len(einsum.b_sum) == 1: + strides['sBN'] = 1 + strides['sBK'] = 1 + strides['sBB'] = strides['K'] + if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum): + strides['sCN'] = 1 + strides['sCB'] = strides['sCM'] = strides['N'] + # Create nested SDFG for GEMM nsdfg = create_batch_gemm_sdfg(dtype, strides) diff --git a/dace/libraries/blas/blas_helpers.py b/dace/libraries/blas/blas_helpers.py index ad749dc2a0..49e5bba32d 100644 --- a/dace/libraries/blas/blas_helpers.py +++ b/dace/libraries/blas/blas_helpers.py @@ -2,6 +2,7 @@ from dace.data import Array from typing import Any, Dict + def to_blastype(dtype): """ Returns a BLAS character that corresponds to the input type. Used in MKL/CUBLAS calls. """ @@ -125,10 +126,10 @@ def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]: else: raise Exception("sAM or sAK should be 1") - if sBK == 1: - optB = 'k' - elif sBN == 1: + if sBN == 1: optB = 'n' + elif sBK == 1: + optB = 'k' else: raise Exception("sBK or sBN should be 1") diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index 6780adc45c..3a322b358e 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -16,7 +16,7 @@ def _get_matmul_inputs(node, state, sdfg): for edge in state.in_edges(node): if edge.dst_conn in ["_a", "_b"]: subset = dc(edge.data.subset) - subset.squeeze() + #subset.squeeze() size = subset.size() outer_array = sdfg.data( dace.sdfg.find_input_arraynode(state, edge).data) @@ -413,11 +413,11 @@ def validate(self, sdfg, state): for _, _, _, dst_conn, memlet in state.in_edges(self): if dst_conn == '_a': subset = dc(memlet.subset) - subset.squeeze() + #subset.squeeze() size0 = subset.size() if dst_conn == '_b': subset = dc(memlet.subset) - subset.squeeze() + #subset.squeeze() size1 = subset.size() out_edges = state.out_edges(self) if len(out_edges) != 1: @@ -440,7 +440,7 @@ def validate(self, sdfg, state): "Inputs to matrix-matrix product must agree in the k-dimension" ) out_subset = dc(out_memlet.subset) - out_subset.squeeze() + #out_subset.squeeze() size2 = out_subset.size() if len(size2) not in [2, 3]: raise ValueError( diff --git a/tests/numpy/einsum_test.py b/tests/numpy/einsum_test.py index 099b937003..eb87ba878d 100644 --- a/tests/numpy/einsum_test.py +++ b/tests/numpy/einsum_test.py @@ -74,4 +74,5 @@ def einsumtest(A: dace.float64[N, N, N, N], B: dace.float64[N, N, N, N], test_general_einsum() test_matmul() test_batch_matmul() + test_opteinsum_sym() test_opteinsum() From bc108f27310893c795537e0a3bdf66e8d4ca9307 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 18:25:01 +0100 Subject: [PATCH 30/32] InlineSDFG: Don't inline in strict mode if more dimensions are inside (reshape) --- .../transformation/interstate/sdfg_nesting.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index b4054db631..11accd0e04 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -3,7 +3,7 @@ from copy import deepcopy as dc import itertools import networkx as nx -from typing import Dict, List, Set +from typing import Dict, List, Set, Optional from dace import memlet, registry, sdfg as sd, Memlet, EmptyMemlet from dace.graph import nodes, nxutil @@ -46,6 +46,18 @@ def expressions(): # Matches anything return [nxutil.node_path_graph(InlineSDFG._nested_sdfg)] + @staticmethod + def _find_edge(state: SDFGState, node: nodes.Node, + connector: str) -> Optional[MultiConnectorEdge]: + for edge in state.in_edges(node): + if edge.dst_conn == connector: + return edge + for edge in state.out_edges(node): + if edge.src_conn == connector: + return edge + raise NameError('Edge with connector %s not found on node %s' % + (connector, node)) + @staticmethod def can_be_applied(graph, candidate, expr_index, sdfg, strict=False): nested_sdfg = graph.nodes()[candidate[InlineSDFG._nested_sdfg]] @@ -82,6 +94,16 @@ def can_be_applied(graph, candidate, expr_index, sdfg, strict=False): if isinstance(e.dst, nodes.AccessNode))): return False + # If some reshaping that cannot be inlined / unsqueezed is happening, + # do not match transformation in strict mode. + if strict: + for aname, array in nested_sdfg.sdfg.arrays.items(): + if array.transient: + continue + edge = InlineSDFG._find_edge(graph, nested_sdfg, aname) + if len(array.shape) > len(edge.data.subset): + return False + return True @staticmethod From 1f3784581cbac67f786a06c411cd5188c3a5bdbd Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 22 Mar 2020 18:25:18 +0100 Subject: [PATCH 31/32] Remove GPU storage in einsum --- dace/frontend/common/einsum.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index a4e57f0d54..3956a2e0de 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -127,20 +127,17 @@ def create_batch_gemm_sdfg(dtype, strides): 'X', dtype=dtype, shape=[BATCH, M, K] if batched else [M, K], - strides=[sAB, sAM, sAK] if batched else [sAM, sAK], - storage=dtypes.StorageType.GPU_Global) + strides=[sAB, sAM, sAK] if batched else [sAM, sAK]) _, yarr = sdfg.add_array( 'Y', dtype=dtype, shape=[BATCH, K, N] if batched else [K, N], - strides=[sBB, sBK, sBN] if batched else [sBK, sBN], - storage=dtypes.StorageType.GPU_Global) + strides=[sBB, sBK, sBN] if batched else [sBK, sBN]) _, zarr = sdfg.add_array( 'Z', dtype=dtype, shape=[BATCH, M, N] if batched else [M, N], - strides=[sCB, sCM, sCN] if batched else [sCM, sCN], - storage=dtypes.StorageType.GPU_Global) + strides=[sCB, sCM, sCN] if batched else [sCM, sCN]) gX = state.add_read('X') gY = state.add_read('Y') From e8bbd38c2760b6af2be5430c90a18b5eef9a9d8c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 23 Mar 2020 08:51:13 +0100 Subject: [PATCH 32/32] Matmult: squeeze, use memlets instead of arrays --- dace/frontend/common/einsum.py | 1 - dace/frontend/python/newast.py | 3 +- dace/libraries/blas/blas_helpers.py | 8 +-- dace/libraries/blas/nodes/matmul.py | 90 +++++++++++++++-------------- 4 files changed, 53 insertions(+), 49 deletions(-) diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index 3956a2e0de..c08b424c96 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -10,7 +10,6 @@ from dace.sdfg import SDFG, SDFGState from dace.memlet import Memlet from dace.frontend.common import op_repository as oprepo -from dace.libraries.blas.blas_helpers import to_blastype, get_gemm_opts def _is_sequential(index_list): diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 58c12d3a74..7ce01a903e 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -682,7 +682,8 @@ def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): from dace.libraries.blas.nodes.matmul import get_batchmm_opts # Determine batched multiplication - bopt = get_batchmm_opts(arr1, arr2, None) + bopt = get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape, arr2.strides, + None, None) if bopt: output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1]) else: diff --git a/dace/libraries/blas/blas_helpers.py b/dace/libraries/blas/blas_helpers.py index 49e5bba32d..d3f9dd9c92 100644 --- a/dace/libraries/blas/blas_helpers.py +++ b/dace/libraries/blas/blas_helpers.py @@ -22,7 +22,7 @@ def to_blastype(dtype): dtype.__name__) -def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]: +def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]: """ Returns GEMM argument order, transposition, and leading dimensions based on column-major storage from dace arrays. @@ -48,9 +48,9 @@ def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]: # | | | # use these 3 to detect correct option - sAM, sAK = a.strides[-2:] - sBK, sBN = b.strides[-2:] - sCM, sCN = c.strides[-2:] + sAM, sAK = a_strides[-2:] + sBK, sBN = b_strides[-2:] + sCM, sCN = c_strides[-2:] opts = { 'mkm': { diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index 3a322b358e..687537486f 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -16,11 +16,14 @@ def _get_matmul_inputs(node, state, sdfg): for edge in state.in_edges(node): if edge.dst_conn in ["_a", "_b"]: subset = dc(edge.data.subset) - #subset.squeeze() + squeezed = subset.squeeze() size = subset.size() outer_array = sdfg.data( dace.sdfg.find_input_arraynode(state, edge).data) - res = edge, outer_array, size + strides = [ + s for i, s in enumerate(outer_array.strides) if i in squeezed + ] + res = edge, outer_array, size, strides if edge.dst_conn == "_a": res_a = res else: @@ -31,7 +34,8 @@ def _get_matmul_inputs(node, state, sdfg): return res_a, res_b -def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]: +def get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, + c_strides) -> Dict[str, Any]: """ Detects whether a matrix multiplication is a batched matrix multiplication and returns its parameters (strides, batch size), or an empty dictionary if @@ -42,27 +46,27 @@ def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]: :return: A dictionary with the following keys: sa,sb,sc (strides for a, b, and c); and b (batch size). """ - if len(a.shape) > 3 or len(b.shape) > 3 or (c and len(c.shape) > 3): + if len(a_shape) > 3 or len(b_shape) > 3 or (c_shape and len(c_shape) > 3): raise ValueError('Tensor dimensions too large for (batched) matrix ' 'multiplication') - if len(a.shape) <= 2 and len(b.shape) <= 2: + if len(a_shape) <= 2 and len(b_shape) <= 2: return {} batch = None stride_a, stride_b, stride_c = 0, 0, 0 - if len(a.shape) == 3: - batch = a.shape[0] - stride_a = a.strides[0] - if len(b.shape) == 3: - if batch and batch != b.shape[0]: + if len(a_shape) == 3: + batch = a_shape[0] + stride_a = a_strides[0] + if len(b_shape) == 3: + if batch and batch != b_shape[0]: raise ValueError('Batch size mismatch for matrix multiplication') - batch = b.shape[0] - stride_b = b.strides[0] - if c and len(c.shape) == 3: - if batch and batch != c.shape[0]: + batch = b_shape[0] + stride_b = b_strides[0] + if c_shape and len(c_shape) == 3: + if batch and batch != c_shape[0]: raise ValueError('Batch size mismatch for matrix multiplication') - batch = c.shape[0] - stride_c = c.strides[0] + batch = c_shape[0] + stride_c = c_strides[0] if batch is None: return {} @@ -76,9 +80,11 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, # Avoid import loops from dace.codegen.targets.common import sym2cpp - (_, _, ashape), (_, _, bshape) = _get_matmul_inputs(node, state, sdfg) - opt = get_gemm_opts(adesc, bdesc, cdesc) - bopt = get_batchmm_opts(adesc, bdesc, cdesc) + (_, _, ashape, astride), (_, _, bshape, + bstride) = _get_matmul_inputs(node, state, sdfg) + opt = get_gemm_opts(astride, bstride, cdesc.strides) + bopt = get_batchmm_opts(ashape, astride, bshape, bstride, cdesc.shape, + cdesc.strides) opt['x'] = '_a' opt['y'] = '_b' opt['M'] = sym2cpp(ashape[-2]) @@ -122,11 +128,12 @@ def make_sdfg(node, parent_state, parent_sdfg): sdfg = dace.SDFG(node.label + "_sdfg") state = sdfg.add_state(node.label + "_state") - ((edge_a, outer_array_a, shape_a), - (edge_b, outer_array_b, - shape_b)) = _get_matmul_inputs(node, parent_state, parent_sdfg) + ((edge_a, outer_array_a, shape_a, strides_a), + (edge_b, outer_array_b, shape_b, + strides_b)) = _get_matmul_inputs(node, parent_state, parent_sdfg) cdesc = parent_sdfg.arrays[parent_state.out_edges(node)[0].data.data] - bopt = get_batchmm_opts(outer_array_a, outer_array_b, cdesc) + bopt = get_batchmm_opts(shape_a, strides_a, shape_b, strides_b, + cdesc.shape, cdesc.strides) if shape_a[-1] != shape_b[-2]: raise SyntaxError('Matrix sizes must match') @@ -231,8 +238,9 @@ def expansion(node, state, sdfg): else: raise ValueError("Unsupported type for BLAS dot product: " + str(dtype)) - (_, adesc, ashape), (_, bdesc, - bshape) = _get_matmul_inputs(node, state, sdfg) + (_, adesc, ashape, + astrides), (_, bdesc, bshape, + bstrides) = _get_matmul_inputs(node, state, sdfg) cdesc = sdfg.arrays[state.out_edges(node)[0].data.data] opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func) @@ -413,11 +421,11 @@ def validate(self, sdfg, state): for _, _, _, dst_conn, memlet in state.in_edges(self): if dst_conn == '_a': subset = dc(memlet.subset) - #subset.squeeze() + subset.squeeze() size0 = subset.size() if dst_conn == '_b': subset = dc(memlet.subset) - #subset.squeeze() + subset.squeeze() size1 = subset.size() out_edges = state.out_edges(self) if len(out_edges) != 1: @@ -425,31 +433,27 @@ def validate(self, sdfg, state): "Expected exactly one output from matrix-matrix product") out_memlet = out_edges[0].data # Function is symmetric, edge order does not matter - bopt = get_batchmm_opts(sdfg.arrays[in_edges[0].data.data], - sdfg.arrays[in_edges[1].data.data], - sdfg.arrays[out_edges[0].data.data]) - if not bopt: - if len(size0) != 2: - raise ValueError( - "matrix-matrix product only supported on matrices") - if len(size1) != 2: - raise ValueError( - "matrix-matrix product only supported on matrices") + if len(size0) not in [2, 3]: + raise ValueError( + "matrix-matrix product only supported on matrices") + if len(size1) not in [2, 3]: + raise ValueError( + "matrix-matrix product only supported on matrices") if size0[-1] != size1[-2]: raise ValueError( "Inputs to matrix-matrix product must agree in the k-dimension" ) out_subset = dc(out_memlet.subset) - #out_subset.squeeze() + out_subset.squeeze() size2 = out_subset.size() if len(size2) not in [2, 3]: raise ValueError( "matrix-matrix product only supported on matrices") - if not bopt and list(size2) != [size0[-2], size1[-1]]: + if len(size2) == 2 and list(size2) != [size0[-2], size1[-1]]: raise ValueError( "Output to matrix-matrix product must agree in the m and n " "dimensions") - if bopt and list(size2) != [bopt['b'], size0[-2], size1[-1]]: - raise ValueError( - "Output to batch matrix-matrix product must agree in the b, " - "m, and n dimensions") + # if len(size2) == 3 and list(size2) != [bopt['b'], size0[-2], size1[-1]]: + # raise ValueError( + # "Output to batch matrix-matrix product must agree in the b, " + # "m, and n dimensions")