diff --git a/dace/frontend/common/einsum.py b/dace/frontend/common/einsum.py index 3956a2e0de..c08b424c96 100644 --- a/dace/frontend/common/einsum.py +++ b/dace/frontend/common/einsum.py @@ -10,7 +10,6 @@ from dace.sdfg import SDFG, SDFGState from dace.memlet import Memlet from dace.frontend.common import op_repository as oprepo -from dace.libraries.blas.blas_helpers import to_blastype, get_gemm_opts def _is_sequential(index_list): diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 58c12d3a74..7ce01a903e 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -682,7 +682,8 @@ def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str): from dace.libraries.blas.nodes.matmul import get_batchmm_opts # Determine batched multiplication - bopt = get_batchmm_opts(arr1, arr2, None) + bopt = get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape, arr2.strides, + None, None) if bopt: output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1]) else: diff --git a/dace/libraries/blas/blas_helpers.py b/dace/libraries/blas/blas_helpers.py index 49e5bba32d..d3f9dd9c92 100644 --- a/dace/libraries/blas/blas_helpers.py +++ b/dace/libraries/blas/blas_helpers.py @@ -22,7 +22,7 @@ def to_blastype(dtype): dtype.__name__) -def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]: +def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]: """ Returns GEMM argument order, transposition, and leading dimensions based on column-major storage from dace arrays. @@ -48,9 +48,9 @@ def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]: # | | | # use these 3 to detect correct option - sAM, sAK = a.strides[-2:] - sBK, sBN = b.strides[-2:] - sCM, sCN = c.strides[-2:] + sAM, sAK = a_strides[-2:] + sBK, sBN = b_strides[-2:] + sCM, sCN = c_strides[-2:] opts = { 'mkm': { diff --git a/dace/libraries/blas/nodes/matmul.py b/dace/libraries/blas/nodes/matmul.py index 3a322b358e..687537486f 100644 --- a/dace/libraries/blas/nodes/matmul.py +++ b/dace/libraries/blas/nodes/matmul.py @@ -16,11 +16,14 @@ def _get_matmul_inputs(node, state, sdfg): for edge in state.in_edges(node): if edge.dst_conn in ["_a", "_b"]: subset = dc(edge.data.subset) - #subset.squeeze() + squeezed = subset.squeeze() size = subset.size() outer_array = sdfg.data( dace.sdfg.find_input_arraynode(state, edge).data) - res = edge, outer_array, size + strides = [ + s for i, s in enumerate(outer_array.strides) if i in squeezed + ] + res = edge, outer_array, size, strides if edge.dst_conn == "_a": res_a = res else: @@ -31,7 +34,8 @@ def _get_matmul_inputs(node, state, sdfg): return res_a, res_b -def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]: +def get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape, + c_strides) -> Dict[str, Any]: """ Detects whether a matrix multiplication is a batched matrix multiplication and returns its parameters (strides, batch size), or an empty dictionary if @@ -42,27 +46,27 @@ def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]: :return: A dictionary with the following keys: sa,sb,sc (strides for a, b, and c); and b (batch size). """ - if len(a.shape) > 3 or len(b.shape) > 3 or (c and len(c.shape) > 3): + if len(a_shape) > 3 or len(b_shape) > 3 or (c_shape and len(c_shape) > 3): raise ValueError('Tensor dimensions too large for (batched) matrix ' 'multiplication') - if len(a.shape) <= 2 and len(b.shape) <= 2: + if len(a_shape) <= 2 and len(b_shape) <= 2: return {} batch = None stride_a, stride_b, stride_c = 0, 0, 0 - if len(a.shape) == 3: - batch = a.shape[0] - stride_a = a.strides[0] - if len(b.shape) == 3: - if batch and batch != b.shape[0]: + if len(a_shape) == 3: + batch = a_shape[0] + stride_a = a_strides[0] + if len(b_shape) == 3: + if batch and batch != b_shape[0]: raise ValueError('Batch size mismatch for matrix multiplication') - batch = b.shape[0] - stride_b = b.strides[0] - if c and len(c.shape) == 3: - if batch and batch != c.shape[0]: + batch = b_shape[0] + stride_b = b_strides[0] + if c_shape and len(c_shape) == 3: + if batch and batch != c_shape[0]: raise ValueError('Batch size mismatch for matrix multiplication') - batch = c.shape[0] - stride_c = c.strides[0] + batch = c_shape[0] + stride_c = c_strides[0] if batch is None: return {} @@ -76,9 +80,11 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, # Avoid import loops from dace.codegen.targets.common import sym2cpp - (_, _, ashape), (_, _, bshape) = _get_matmul_inputs(node, state, sdfg) - opt = get_gemm_opts(adesc, bdesc, cdesc) - bopt = get_batchmm_opts(adesc, bdesc, cdesc) + (_, _, ashape, astride), (_, _, bshape, + bstride) = _get_matmul_inputs(node, state, sdfg) + opt = get_gemm_opts(astride, bstride, cdesc.strides) + bopt = get_batchmm_opts(ashape, astride, bshape, bstride, cdesc.shape, + cdesc.strides) opt['x'] = '_a' opt['y'] = '_b' opt['M'] = sym2cpp(ashape[-2]) @@ -122,11 +128,12 @@ def make_sdfg(node, parent_state, parent_sdfg): sdfg = dace.SDFG(node.label + "_sdfg") state = sdfg.add_state(node.label + "_state") - ((edge_a, outer_array_a, shape_a), - (edge_b, outer_array_b, - shape_b)) = _get_matmul_inputs(node, parent_state, parent_sdfg) + ((edge_a, outer_array_a, shape_a, strides_a), + (edge_b, outer_array_b, shape_b, + strides_b)) = _get_matmul_inputs(node, parent_state, parent_sdfg) cdesc = parent_sdfg.arrays[parent_state.out_edges(node)[0].data.data] - bopt = get_batchmm_opts(outer_array_a, outer_array_b, cdesc) + bopt = get_batchmm_opts(shape_a, strides_a, shape_b, strides_b, + cdesc.shape, cdesc.strides) if shape_a[-1] != shape_b[-2]: raise SyntaxError('Matrix sizes must match') @@ -231,8 +238,9 @@ def expansion(node, state, sdfg): else: raise ValueError("Unsupported type for BLAS dot product: " + str(dtype)) - (_, adesc, ashape), (_, bdesc, - bshape) = _get_matmul_inputs(node, state, sdfg) + (_, adesc, ashape, + astrides), (_, bdesc, bshape, + bstrides) = _get_matmul_inputs(node, state, sdfg) cdesc = sdfg.arrays[state.out_edges(node)[0].data.data] opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func) @@ -413,11 +421,11 @@ def validate(self, sdfg, state): for _, _, _, dst_conn, memlet in state.in_edges(self): if dst_conn == '_a': subset = dc(memlet.subset) - #subset.squeeze() + subset.squeeze() size0 = subset.size() if dst_conn == '_b': subset = dc(memlet.subset) - #subset.squeeze() + subset.squeeze() size1 = subset.size() out_edges = state.out_edges(self) if len(out_edges) != 1: @@ -425,31 +433,27 @@ def validate(self, sdfg, state): "Expected exactly one output from matrix-matrix product") out_memlet = out_edges[0].data # Function is symmetric, edge order does not matter - bopt = get_batchmm_opts(sdfg.arrays[in_edges[0].data.data], - sdfg.arrays[in_edges[1].data.data], - sdfg.arrays[out_edges[0].data.data]) - if not bopt: - if len(size0) != 2: - raise ValueError( - "matrix-matrix product only supported on matrices") - if len(size1) != 2: - raise ValueError( - "matrix-matrix product only supported on matrices") + if len(size0) not in [2, 3]: + raise ValueError( + "matrix-matrix product only supported on matrices") + if len(size1) not in [2, 3]: + raise ValueError( + "matrix-matrix product only supported on matrices") if size0[-1] != size1[-2]: raise ValueError( "Inputs to matrix-matrix product must agree in the k-dimension" ) out_subset = dc(out_memlet.subset) - #out_subset.squeeze() + out_subset.squeeze() size2 = out_subset.size() if len(size2) not in [2, 3]: raise ValueError( "matrix-matrix product only supported on matrices") - if not bopt and list(size2) != [size0[-2], size1[-1]]: + if len(size2) == 2 and list(size2) != [size0[-2], size1[-1]]: raise ValueError( "Output to matrix-matrix product must agree in the m and n " "dimensions") - if bopt and list(size2) != [bopt['b'], size0[-2], size1[-1]]: - raise ValueError( - "Output to batch matrix-matrix product must agree in the b, " - "m, and n dimensions") + # if len(size2) == 3 and list(size2) != [bopt['b'], size0[-2], size1[-1]]: + # raise ValueError( + # "Output to batch matrix-matrix product must agree in the b, " + # "m, and n dimensions")