Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

einsum support and minor SDFG API updates #172

Merged
merged 33 commits into from
Mar 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
fcd57ec
Fix pytest-confusing syntax
tbennun Mar 14, 2020
134e0eb
Zero memory for return values on the python side (to be safe)
tbennun Mar 14, 2020
a30b99c
Fix add_mapped_tasklet API for disjoint input/output_nodes and connec…
tbennun Mar 15, 2020
c3f1316
New API for adding a state before/after an existing state
tbennun Mar 15, 2020
dba0f1c
When inlining SDFGs, try to squeeze internal memlets first if dimensi…
tbennun Mar 15, 2020
00cb92c
Clean up dependencies vs. test dependencies
tbennun Mar 15, 2020
e838e1a
Einsum function replacement with opt_einsum optimizer and GEMM specia…
tbennun Mar 15, 2020
3f94010
Update setup.py
tbennun Mar 15, 2020
20eac1c
Update Jenkinsfile
tbennun Mar 15, 2020
c7dfbb9
Update .travis.yml
tbennun Mar 15, 2020
3f2f795
Merge branch 'master' into einsum
tbennun Mar 15, 2020
184964c
Refactor BLAS helpers into a file
tbennun Mar 16, 2020
5f9fe72
Fix minor issues in einsum
tbennun Mar 16, 2020
37eae5f
CUBLAS matrix multiplication + test
tbennun Mar 16, 2020
661560e
Added matmul test to CI
tbennun Mar 16, 2020
f1266f0
matmul tests are more exhaustive (data layout, types)
tbennun Mar 16, 2020
4a1fe4a
Fix for unsupported case in CI
tbennun Mar 16, 2020
f4f4680
Increase test timeout
tbennun Mar 16, 2020
021cf32
Minor update
tbennun Mar 17, 2020
ba49c24
Add strided-batched matrix multiplication support to CUBLAS MatMult
tbennun Mar 22, 2020
0a49966
GPUTransformSDFG: Do not transform library nodes
tbennun Mar 22, 2020
3948ce8
Python frontend: Support batched matrix multiplication with @ operator
tbennun Mar 22, 2020
441c10d
Batched matrix multiplication CUDA test
tbennun Mar 22, 2020
48a2532
Matrix multiplication operator test
tbennun Mar 22, 2020
49c57ce
Pure implementation of batched matrix mutliplication
tbennun Mar 22, 2020
6565cfe
MKL version of matrix multiplication node
tbennun Mar 22, 2020
d742adf
MatMult: Fix code generation symbol output
tbennun Mar 22, 2020
4cc2836
CUBLAS MatMult: Copy data to GPU if not already there
tbennun Mar 22, 2020
7600196
Einsum: Use library nodes when possible
tbennun Mar 22, 2020
5bc6381
Adapt einsum replacement to work with dot products
tbennun Mar 22, 2020
bc108f2
InlineSDFG: Don't inline in strict mode if more dimensions are inside…
tbennun Mar 22, 2020
1f37845
Remove GPU storage in einsum
tbennun Mar 22, 2020
e8bbd38
Matmult: squeeze, use memlets instead of arrays
tbennun Mar 23, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ before_install:
- sudo apt-get install libpapi-dev papi-tools

install:
- pip install ".[test]" .
- pip install ".[testing]" .
- pip install coverage codecov
- if [ $DACE_optimizer_automatic_strict_transformations -eq 1 ]; then pip install tensorflow==1.15.0; fi

Expand Down
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pipeline {
echo "Installing additional dependencies"
pip3 install --upgrade --user tensorflow-gpu==1.14.0
echo "Installing DaCe"
pip3 install --ignore-installed --upgrade --user ".[test]" .
pip3 install --ignore-installed --upgrade --user ".[testing]" .
pip3 install --user cmake
pip3 install --user coverage
'''
Expand Down
2 changes: 1 addition & 1 deletion dace/codegen/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _initialize_return_values(self, kwargs):
self._return_arrays.append(
np.ndarray([symbolic.evaluate(s, syms) for s in arr.shape],
arr.dtype.type,
buffer=np.ndarray(
buffer=np.zeros(
[symbolic.evaluate(arr.total_size, syms)],
arr.dtype.type),
strides=[
Expand Down
1 change: 1 addition & 0 deletions dace/frontend/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .op_impl import constant_array_multiplication
from .op_impl import matrix_transpose, matrix_transpose_s
from .op_impl import matrix_pointwise_op
from .einsum import create_einsum_sdfg
351 changes: 351 additions & 0 deletions dace/frontend/common/einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,351 @@
""" Classes to handle Einstein-notation sums (einsum) as a library node. """
from functools import reduce
from itertools import chain
from string import ascii_letters
from typing import Dict, Optional

from dace import dtypes, symbolic
from dace.graph.nodes import AccessNode
from dace.graph.edges import InterstateEdge
from dace.sdfg import SDFG, SDFGState
from dace.memlet import Memlet
from dace.frontend.common import op_repository as oprepo


def _is_sequential(index_list):
if not index_list:
return True
index_list = sorted(index_list)
smallest_elem = index_list[0]
return index_list == list(
range(smallest_elem, smallest_elem + len(index_list)))


class EinsumParser(object):
""" String parser for einsum. """
def __init__(self, string):
inout = string.split('->')
if len(inout) == 1:
inputs, output = string, ''
else:
inputs, output = inout

for char in chain(inputs, output):
if char not in ascii_letters + ',':
raise ValueError(
'Invalid einsum string, subscript must contain'
' letters, commas, and "->".')

inputs = inputs.split(',')

# No output given, assumed all "free" subscripts in inputs
if len(inout) == 1:
# Find intersection and union of all inputs for the non-outputs
# and free inputs
nonfree = set()
free = set()
for i, inp in enumerate(inputs):
for var in set(inp):
if (all(var not in set(s) for s in inputs[i + 1:])
and var not in nonfree):
free.add(var)
else:
nonfree.add(var)
output = ''.join(sorted(free))

self.inputs = inputs
self.output = output
if len(inputs) != 2:
return

# Special case: contracting two tensors
a, b = inputs
c = output
a_vars = set(a)
b_vars = set(b)
ab_vars = a_vars.union(b_vars)
c_vars = set(c)
if not ab_vars.issuperset(c_vars):
raise ValueError('Einsum subscript string includes outputs that do'
' not appear as an input')

batch_vars = a_vars.intersection(b_vars).intersection(c_vars)
sum_vars = a_vars.intersection(b_vars) - c_vars
a_only_vars = a_vars - sum_vars - batch_vars
b_only_vars = b_vars - sum_vars - batch_vars

self.a_batch = [i for i, d in enumerate(a) if d in batch_vars]
self.a_sum = [i for i, d in enumerate(a) if d in sum_vars]
self.a_only = [i for i, d in enumerate(a) if d in a_only_vars]

self.b_batch = [i for i, d in enumerate(b) if d in batch_vars]
self.b_sum = [i for i, d in enumerate(b) if d in sum_vars]
self.b_only = [i for i, d in enumerate(b) if d in b_only_vars]

self.c_a_only = [i for i, d in enumerate(c) if d in a_only_vars]
self.c_b_only = [i for i, d in enumerate(c) if d in b_only_vars]
self.c_batch = [i for i, d in enumerate(c) if d in batch_vars]

def is_bmm(self):
if len(self.inputs) != 2:
return False
for key, val in self.fields().items():
if not _is_sequential(val):
return False
return True

def fields(self):
return {
fname: fval
for fname, fval in self.__dict__.items()
if fname not in ('inputs', 'output')
}

def __str__(self):
return str(self.__dict__)

def __repr__(self):
return str(self)


def create_batch_gemm_sdfg(dtype, strides):
#########################
sdfg = SDFG('einsum')
state = sdfg.add_state()
M, K, N = (symbolic.symbol(s) for s in ['M', 'K', 'N'])
BATCH, sAM, sAK, sAB, sBK, sBN, sBB, sCM, sCN, sCB = (
symbolic.symbol(s) if symbolic.issymbolic(strides[s]) else strides[s]
for s in [
'BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM', 'sCN',
'sCB'
])

batched = strides['BATCH'] != 1

_, xarr = sdfg.add_array(
'X',
dtype=dtype,
shape=[BATCH, M, K] if batched else [M, K],
strides=[sAB, sAM, sAK] if batched else [sAM, sAK])
_, yarr = sdfg.add_array(
'Y',
dtype=dtype,
shape=[BATCH, K, N] if batched else [K, N],
strides=[sBB, sBK, sBN] if batched else [sBK, sBN])
_, zarr = sdfg.add_array(
'Z',
dtype=dtype,
shape=[BATCH, M, N] if batched else [M, N],
strides=[sCB, sCM, sCN] if batched else [sCM, sCN])

gX = state.add_read('X')
gY = state.add_read('Y')
gZ = state.add_write('Z')

import dace.libraries.blas as blas # Avoid import loop

libnode = blas.MatMul('einsum_gemm', zarr.dtype)
state.add_node(libnode)
state.add_edge(gX, None, libnode, '_a', Memlet.from_array(gX.data, xarr))
state.add_edge(gY, None, libnode, '_b', Memlet.from_array(gY.data, yarr))
state.add_edge(libnode, '_c', gZ, None, Memlet.from_array(gZ.data, zarr))

return sdfg


def prod(iterable):
return reduce(lambda x, y: x * y, iterable, 1)


@oprepo.replaces('numpy.einsum')
def create_einsum_sdfg(sdfg: SDFG,
state: SDFGState,
einsum_string: str,
*arrays: str,
dtype: Optional[dtypes.typeclass] = None,
optimize: bool = False,
output: Optional[str] = None):
return _create_einsum_internal(sdfg,
state,
einsum_string,
*arrays,
dtype=dtype,
optimize=optimize,
output=output)[0]


def _create_einsum_internal(sdfg: SDFG,
state: SDFGState,
einsum_string: str,
*arrays: str,
dtype: Optional[dtypes.typeclass] = None,
optimize: bool = False,
output: Optional[str] = None,
nodes: Optional[Dict[str, AccessNode]] = None):
# Infer shapes and strides of input/output arrays
einsum = EinsumParser(einsum_string)

if len(einsum.inputs) != len(arrays):
raise ValueError('Invalid number of arrays for einsum expression')

# Get shapes from arrays and verify dimensionality
chardict = {}
for inp, inpname in zip(einsum.inputs, arrays):
inparr = sdfg.arrays[inpname]
if len(inp) != len(inparr.shape):
raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
for char, shp in zip(inp, inparr.shape):
if char in chardict and shp != chardict[char]:
raise ValueError('Dimension mismatch in einsum expression')
chardict[char] = shp

if optimize:
# Try to import opt_einsum
try:
import opt_einsum as oe
except (ModuleNotFoundError, NameError, ImportError):
raise ImportError('To optimize einsum expressions, please install '
'the "opt_einsum" package.')

for char, shp in chardict.items():
if symbolic.issymbolic(shp):
raise ValueError('Einsum optimization cannot be performed '
'on symbolically-sized array dimension "%s" '
'for subscript character "%s"' % (shp, char))

# Create optimal contraction path
# noinspection PyTypeChecker
_, path_info = oe.contract_path(
einsum_string, *oe.helpers.build_views(einsum_string, chardict))

input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
result_node = None

# Follow path and create a chain of operation SDFG states
for pair, nonfree, expr, after, blas in path_info.contraction_list:
result, result_node = _create_einsum_internal(sdfg,
state,
expr,
arrays[pair[0]],
arrays[pair[1]],
dtype=dtype,
optimize=False,
output=None,
nodes=input_nodes)
arrays = ([a for i, a in enumerate(arrays) if i not in pair] +
[result])
input_nodes[result] = result_node

return arrays[0], result_node
# END of einsum optimization

input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}

# Get output shape from chardict, or [1] for a scalar output
output_shape = list(map(lambda k: chardict[k], einsum.output)) or [1]
output_index = ','.join(o for o in einsum.output) or '0'

if output is None:
dtype = dtype or sdfg.arrays[arrays[0]].dtype
output, odesc = sdfg.add_temp_transient(output_shape, dtype)
to_init = True
else:
odesc = sdfg.arrays[output]
dtype = dtype or odesc.dtype
to_init = False

if not einsum.is_bmm():
# Fall back to "pure" SDFG einsum with conflict resolution
c = state.add_write(output)

# Add state before this one to initialize the output value
if to_init:
init_state = sdfg.add_state_before(state)
if len(einsum.output) > 0:
init_state.add_mapped_tasklet(
'einsum_reset',
{k: '0:%s' % chardict[k]
for k in einsum.output}, {},
'out_%s = 0' % output,
{'out_%s' % output: Memlet.simple(output, output_index)},
external_edges=True)
else: # Scalar output
t = init_state.add_tasklet('einsum_reset', set(),
{'out_%s' % output},
'out_%s = 0' % output)
onode = init_state.add_write(output)
init_state.add_edge(t, 'out_%s' % output, onode, None,
Memlet.simple(output, '0'))

# Pure einsum map
state.add_mapped_tasklet(
'einsum', {k: '0:%s' % v
for k, v in chardict.items()}, {
'inp_%s' % arr: Memlet.simple(arr, ','.join(inp))
for inp, arr in zip(einsum.inputs, arrays)
},
'out_%s = %s' % (output, ' * '.join('inp_%s' % arr
for arr in arrays)),
{
'out_%s' % output:
Memlet.simple(output, output_index, wcr_str='lambda a,b: a+b')
},
input_nodes=input_nodes,
output_nodes={output: c},
external_edges=True)
else:
# Represent einsum as a GEMM or batched GEMM (using library nodes)
a_shape = sdfg.arrays[arrays[0]].shape
b_shape = sdfg.arrays[arrays[1]].shape
c_shape = output_shape

a = input_nodes[arrays[0]]
b = input_nodes[arrays[1]]
c = state.add_write(output)

# Compute GEMM dimensions and strides
strides = dict(
BATCH=prod([c_shape[dim] for dim in einsum.c_batch]),
M=prod([a_shape[dim] for dim in einsum.a_only]),
K=prod([a_shape[dim] for dim in einsum.a_sum]),
N=prod([b_shape[dim] for dim in einsum.b_only]),
sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1,
sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1,
sAB=prod(a_shape[einsum.a_batch[-1] +
1:]) if einsum.a_batch else 1,
sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1,
sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1,
sBB=prod(b_shape[einsum.b_batch[-1] +
1:]) if einsum.b_batch else 1,
sCM=prod(c_shape[einsum.c_a_only[-1] +
1:]) if einsum.c_a_only else 1,
sCN=prod(c_shape[einsum.c_b_only[-1] +
1:]) if einsum.c_b_only else 1,
sCB=prod(c_shape[einsum.c_batch[-1] +
1:]) if einsum.c_batch else 1)

# Complement strides to make matrices as necessary
if len(a_shape) == 1 and len(einsum.a_sum) == 1:
strides['sAK'] = 1
strides['sAB'] = strides['sAM'] = strides['K']
if len(b_shape) == 1 and len(einsum.b_sum) == 1:
strides['sBN'] = 1
strides['sBK'] = 1
strides['sBB'] = strides['K']
if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum):
strides['sCN'] = 1
strides['sCB'] = strides['sCM'] = strides['N']

# Create nested SDFG for GEMM
nsdfg = create_batch_gemm_sdfg(dtype, strides)

nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'},
strides)
state.add_edge(a, None, nsdfg_node, 'X',
Memlet.from_array(a.data, a.desc(sdfg)))
state.add_edge(b, None, nsdfg_node, 'Y',
Memlet.from_array(b.data, b.desc(sdfg)))
state.add_edge(nsdfg_node, 'Z', c, None,
Memlet.from_array(c.data, c.desc(sdfg)))

return output, c
Loading