Skip to content

Commit

Permalink
Matmult: squeeze, use memlets instead of arrays
Browse files: browse the repository at this point in the history
  • Loading branch information
tbennun committed Mar 23, 2020
1 parent 1f37845 commit e8bbd38
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 49 deletions.
1 change: 0 additions & 1 deletion dace/frontend/common/einsum.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,6 @@
from dace.sdfg import SDFG, SDFGState
from dace.memlet import Memlet
from dace.frontend.common import op_repository as oprepo
from dace.libraries.blas.blas_helpers import to_blastype, get_gemm_opts


def _is_sequential(index_list):
Expand Down
3 changes: 2 additions & 1 deletion dace/frontend/python/newast.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -682,7 +682,8 @@ def _matmult(visitor, sdfg: SDFG, state: SDFGState, op1: str, op2: str):
from dace.libraries.blas.nodes.matmul import get_batchmm_opts

# Determine batched multiplication
bopt = get_batchmm_opts(arr1, arr2, None)
bopt = get_batchmm_opts(arr1.shape, arr1.strides, arr2.shape, arr2.strides,
None, None)
if bopt:
output_shape = (bopt['b'], arr1.shape[-2], arr2.shape[-1])
else:
Expand Down
8 changes: 4 additions & 4 deletions dace/libraries/blas/blas_helpers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,7 +22,7 @@ def to_blastype(dtype):
dtype.__name__)


def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]:
def get_gemm_opts(a_strides, b_strides, c_strides) -> Dict[str, Any]:
"""
Returns GEMM argument order, transposition, and leading dimensions
based on column-major storage from dace arrays.
Expand All @@ -48,9 +48,9 @@ def get_gemm_opts(a: Array, b: Array, c: Array) -> Dict[str, Any]:
# | | |
# use these 3 to detect correct option

sAM, sAK = a.strides[-2:]
sBK, sBN = b.strides[-2:]
sCM, sCN = c.strides[-2:]
sAM, sAK = a_strides[-2:]
sBK, sBN = b_strides[-2:]
sCM, sCN = c_strides[-2:]

opts = {
'mkm': {
Expand Down
90 changes: 47 additions & 43 deletions dace/libraries/blas/nodes/matmul.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,11 +16,14 @@ def _get_matmul_inputs(node, state, sdfg):
for edge in state.in_edges(node):
if edge.dst_conn in ["_a", "_b"]:
subset = dc(edge.data.subset)
#subset.squeeze()
squeezed = subset.squeeze()
size = subset.size()
outer_array = sdfg.data(
dace.sdfg.find_input_arraynode(state, edge).data)
res = edge, outer_array, size
strides = [
s for i, s in enumerate(outer_array.strides) if i in squeezed
]
res = edge, outer_array, size, strides
if edge.dst_conn == "_a":
res_a = res
else:
Expand All @@ -31,7 +34,8 @@ def _get_matmul_inputs(node, state, sdfg):
return res_a, res_b


def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]:
def get_batchmm_opts(a_shape, a_strides, b_shape, b_strides, c_shape,
c_strides) -> Dict[str, Any]:
"""
Detects whether a matrix multiplication is a batched matrix multiplication
and returns its parameters (strides, batch size), or an empty dictionary if
Expand All @@ -42,27 +46,27 @@ def get_batchmm_opts(a: Array, b: Array, c: Optional[Array]) -> Dict[str, Any]:
:return: A dictionary with the following keys: sa,sb,sc (strides for a, b,
and c); and b (batch size).
"""
if len(a.shape) > 3 or len(b.shape) > 3 or (c and len(c.shape) > 3):
if len(a_shape) > 3 or len(b_shape) > 3 or (c_shape and len(c_shape) > 3):
raise ValueError('Tensor dimensions too large for (batched) matrix '
'multiplication')
if len(a.shape) <= 2 and len(b.shape) <= 2:
if len(a_shape) <= 2 and len(b_shape) <= 2:
return {}

batch = None
stride_a, stride_b, stride_c = 0, 0, 0
if len(a.shape) == 3:
batch = a.shape[0]
stride_a = a.strides[0]
if len(b.shape) == 3:
if batch and batch != b.shape[0]:
if len(a_shape) == 3:
batch = a_shape[0]
stride_a = a_strides[0]
if len(b_shape) == 3:
if batch and batch != b_shape[0]:
raise ValueError('Batch size mismatch for matrix multiplication')
batch = b.shape[0]
stride_b = b.strides[0]
if c and len(c.shape) == 3:
if batch and batch != c.shape[0]:
batch = b_shape[0]
stride_b = b_strides[0]
if c_shape and len(c_shape) == 3:
if batch and batch != c_shape[0]:
raise ValueError('Batch size mismatch for matrix multiplication')
batch = c.shape[0]
stride_c = c.strides[0]
batch = c_shape[0]
stride_c = c_strides[0]

if batch is None:
return {}
Expand All @@ -76,9 +80,11 @@ def _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta,
# Avoid import loops
from dace.codegen.targets.common import sym2cpp

(_, _, ashape), (_, _, bshape) = _get_matmul_inputs(node, state, sdfg)
opt = get_gemm_opts(adesc, bdesc, cdesc)
bopt = get_batchmm_opts(adesc, bdesc, cdesc)
(_, _, ashape, astride), (_, _, bshape,
bstride) = _get_matmul_inputs(node, state, sdfg)
opt = get_gemm_opts(astride, bstride, cdesc.strides)
bopt = get_batchmm_opts(ashape, astride, bshape, bstride, cdesc.shape,
cdesc.strides)
opt['x'] = '_a'
opt['y'] = '_b'
opt['M'] = sym2cpp(ashape[-2])
Expand Down Expand Up @@ -122,11 +128,12 @@ def make_sdfg(node, parent_state, parent_sdfg):
sdfg = dace.SDFG(node.label + "_sdfg")
state = sdfg.add_state(node.label + "_state")

((edge_a, outer_array_a, shape_a),
(edge_b, outer_array_b,
shape_b)) = _get_matmul_inputs(node, parent_state, parent_sdfg)
((edge_a, outer_array_a, shape_a, strides_a),
(edge_b, outer_array_b, shape_b,
strides_b)) = _get_matmul_inputs(node, parent_state, parent_sdfg)
cdesc = parent_sdfg.arrays[parent_state.out_edges(node)[0].data.data]
bopt = get_batchmm_opts(outer_array_a, outer_array_b, cdesc)
bopt = get_batchmm_opts(shape_a, strides_a, shape_b, strides_b,
cdesc.shape, cdesc.strides)

if shape_a[-1] != shape_b[-2]:
raise SyntaxError('Matrix sizes must match')
Expand Down Expand Up @@ -231,8 +238,9 @@ def expansion(node, state, sdfg):
else:
raise ValueError("Unsupported type for BLAS dot product: " +
str(dtype))
(_, adesc, ashape), (_, bdesc,
bshape) = _get_matmul_inputs(node, state, sdfg)
(_, adesc, ashape,
astrides), (_, bdesc, bshape,
bstrides) = _get_matmul_inputs(node, state, sdfg)
cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc,
alpha, beta, cdesc.dtype.ctype, func)
Expand Down Expand Up @@ -413,43 +421,39 @@ def validate(self, sdfg, state):
for _, _, _, dst_conn, memlet in state.in_edges(self):
if dst_conn == '_a':
subset = dc(memlet.subset)
#subset.squeeze()
subset.squeeze()
size0 = subset.size()
if dst_conn == '_b':
subset = dc(memlet.subset)
#subset.squeeze()
subset.squeeze()
size1 = subset.size()
out_edges = state.out_edges(self)
if len(out_edges) != 1:
raise ValueError(
"Expected exactly one output from matrix-matrix product")
out_memlet = out_edges[0].data
# Function is symmetric, edge order does not matter
bopt = get_batchmm_opts(sdfg.arrays[in_edges[0].data.data],
sdfg.arrays[in_edges[1].data.data],
sdfg.arrays[out_edges[0].data.data])
if not bopt:
if len(size0) != 2:
raise ValueError(
"matrix-matrix product only supported on matrices")
if len(size1) != 2:
raise ValueError(
"matrix-matrix product only supported on matrices")
if len(size0) not in [2, 3]:
raise ValueError(
"matrix-matrix product only supported on matrices")
if len(size1) not in [2, 3]:
raise ValueError(
"matrix-matrix product only supported on matrices")
if size0[-1] != size1[-2]:
raise ValueError(
"Inputs to matrix-matrix product must agree in the k-dimension"
)
out_subset = dc(out_memlet.subset)
#out_subset.squeeze()
out_subset.squeeze()
size2 = out_subset.size()
if len(size2) not in [2, 3]:
raise ValueError(
"matrix-matrix product only supported on matrices")
if not bopt and list(size2) != [size0[-2], size1[-1]]:
if len(size2) == 2 and list(size2) != [size0[-2], size1[-1]]:
raise ValueError(
"Output to matrix-matrix product must agree in the m and n "
"dimensions")
if bopt and list(size2) != [bopt['b'], size0[-2], size1[-1]]:
raise ValueError(
"Output to batch matrix-matrix product must agree in the b, "
"m, and n dimensions")
# if len(size2) == 3 and list(size2) != [bopt['b'], size0[-2], size1[-1]]:
# raise ValueError(
# "Output to batch matrix-matrix product must agree in the b, "
# "m, and n dimensions")

0 comments on commit e8bbd38

Please sign in to comment.