[Frontend][ONNX] LSTM Support (apache#4825)
* Initial version working and passing tests.

* WIP on supporting other activations.

* add support for multiple activation functions in lstm

* All tests working and code cleaned up.

* Undo import swap to avoid conflict with masahi.

* Added new tests and related bug fixes.

Co-authored-by: Matthew Brookhart <[email protected]>
2 people authored and alexwong committed Feb 26, 2020
1 parent 80633a8 commit 2886f9a
Showing 2 changed files with 393 additions and 6 deletions.
223 changes: 217 additions & 6 deletions python/tvm/relay/frontend/onnx.py
@@ -32,6 +32,55 @@
__all__ = ['from_onnx']


class onnx_input():
    """ Dual purpose list or dictionary access object."""

    def __init__(self):
        self.input_keys = []
        self.input_dict = {}

    def __getitem__(self, item):
        if isinstance(item, int):
            return self.input_dict[self.input_keys[item]]
        if isinstance(item, str):
            if item not in self.input_keys:
                return None
            return self.input_dict[item]
        if isinstance(item, slice):
            keys = self.input_keys[item]
            return [self.input_dict[key] for key in keys]

        raise ValueError("Only integer, string, and slice accesses allowed.")

    def __setitem__(self, item, value):
        if isinstance(item, int):
            self.input_dict[self.input_keys[item]] = value
        elif isinstance(item, str):
            if item not in self.input_dict:
                self.input_keys.append(item)
            self.input_dict[item] = value
        else:
            raise ValueError("Only integer and string indexed writes allowed.")

    def keys(self):
        return self.input_keys

    def __len__(self):
        return len(self.input_keys)

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < len(self.input_keys):
            output = self.input_dict[self.input_keys[self.n]]
            self.n += 1
            return output

        raise StopIteration

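The container above lets converters index node inputs either by position or by the formal ONNX input name, with missing optional inputs reading back as None. A small illustrative sketch of that behaviour (not part of the commit; the tensor names are made up):

# Illustrative usage of onnx_input; names are hypothetical stand-ins for Relay expressions.
example = onnx_input()
example['X'] = 'data_expr'
example['W'] = 'weight_expr'
assert example[0] == 'data_expr'          # positional, list-style access
assert example['W'] == 'weight_expr'      # dict-style access by name
assert example['B'] is None               # absent optional input reads as None
assert example[0:2] == ['data_expr', 'weight_expr']   # slicing also works
assert len(example) == 2 and list(example) == ['data_expr', 'weight_expr']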

def get_numpy(tensor_proto):
    """Grab data in TensorProto and convert to numpy array."""
    try:
@@ -664,13 +713,24 @@ def _impl_v1(cls, inputs, attr, params):
        return inputs[len(inputs) - 1]


class Affine(OnnxOpConverter):
    """ Operator converter for Affine transformation.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        alpha = _expr.const(attr.get('alpha', 1.0))
        beta = _expr.const(attr.get('beta', 0.0))
        return (alpha * inputs[0]) + beta

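Affine gains a converter here presumably so it can be named as an LSTM activation below; its semantics are simply y = alpha * x + beta applied element-wise. A quick NumPy sketch of the expected behaviour (illustrative only, using the same attribute defaults as above):

import numpy as np

def affine_reference(x, alpha=1.0, beta=0.0):
    # Element-wise affine transform, matching the converter's alpha/beta defaults.
    return alpha * x + beta

x = np.array([-1.0, 0.0, 2.0], dtype="float32")
print(affine_reference(x, alpha=2.0, beta=0.5))   # [-1.5  0.5  4.5]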

class ThresholdedRelu(OnnxOpConverter):
    """ Operator converter for ThresholdedRelu.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        alpha = float(attr.get('alpha', 0.0))
+        alpha = float(attr.get('alpha', 1.0))
        alpha_tensor = _op.full_like(inputs[0], fill_value=_expr.const(alpha))
        mask = _op.greater(inputs[0], alpha_tensor).astype("float32")
        return inputs[0] * mask
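The change above also moves the default alpha to 1.0, which matches the ONNX default for ThresholdedRelu; the op keeps values strictly greater than alpha and zeroes the rest, which the mask-multiply implements. A brief NumPy sketch of the same semantics (illustrative only):

import numpy as np

def thresholded_relu_reference(x, alpha=1.0):
    mask = (x > alpha).astype("float32")   # 1.0 where x exceeds alpha, else 0.0
    return x * mask

x = np.array([-2.0, 0.5, 1.0, 3.0], dtype="float32")
print(thresholded_relu_reference(x))   # [0. 0. 0. 3.]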
@@ -893,7 +953,7 @@ class Maximum(OnnxOpConverter):
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _max = inputs[0]
        for i in range(1, len(inputs)):
@@ -905,7 +965,7 @@ class Minimum(OnnxOpConverter):
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        _min = inputs[0]
        for i in range(1, len(inputs)):
@@ -917,7 +977,7 @@ class Mean(OnnxOpConverter):
    """
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
-        if not isinstance(inputs, list) or len(inputs) < 2:
+        if not isinstance(inputs, (list, onnx_input)) or len(inputs) < 2:
            raise ValueError("Expect minimum 2 inputs")
        # avoid overflow
        concat = _op.concatenate([_op.expand_dims(x, axis=0) for x in inputs], axis=0)
@@ -1190,6 +1250,151 @@ def expand_shape(in_shape, shape):
        return _op.broadcast_to(inputs[0], shape=tuple(shape))


class LSTM(OnnxOpConverter):
    """ Operator converter for LSTM.
    """

    @classmethod
    def _activation_helper(cls, activation, alpha, beta):
        convert_map = _get_convert_map(1)
        attrs = {}
        if alpha is not None:
            attrs['alpha'] = alpha
        if beta is not None:
            attrs['beta'] = beta
        return lambda x: convert_map[activation.decode("utf-8")]([x], attrs, {})

    @classmethod
    def _activation_needs_alpha(cls, activation):
        needs_alpha = [
            "Affine",
            "LeakyRelu",
            "ThresholdedRelu",
            "ScaledTanh",
            "HardSigmoid",
            "Elu",
        ]
        return activation.decode("utf-8") in needs_alpha

    @classmethod
    def _activation_needs_beta(cls, activation):
        needs_beta = [
            "Affine",
            "ScaledTanh",
            "HardSigmoid",
        ]
        return activation.decode("utf-8") in needs_beta

    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        # Unpack inputs, note that if optional and not provided then value will be None.
        X = inputs[0]
        W = inputs[1]
        R = inputs[2]
        B = inputs['B']
        # Sequence length currently unused as it can be inferred from shapes.
        #sequence_lens = inputs['sequence_lens']
        h_0 = inputs['initial_h']
        c_0 = inputs['initial_c']
        P = inputs['P']

        num_directions = infer_shape(W)[0]
        W_dtype = infer_type(W).type_annotation.dtype

        if num_directions != 1:
            raise NotImplementedError("Bidirectional LSTMs not yet supported.")
        # Remove num_directions axis from weights.
        W = _op.squeeze(W, axis=[0])
        R = _op.squeeze(R, axis=[0])
        if B is not None:
            B = _op.squeeze(B, axis=[0])

        X_shape = infer_shape(X)
        hidden_size = infer_shape(R)[-1]
        batch_size = X_shape[1]

        # Initialize state if not provided.
        # Otherwise remove bidirectional axis.
        if h_0 is None:
            h_0 = _op.zeros((batch_size, hidden_size), W_dtype)
        else:
            h_0 = _op.squeeze(h_0, axis=[0])
        if c_0 is None:
            c_0 = _op.zeros((batch_size, hidden_size), W_dtype)
        else:
            c_0 = _op.squeeze(c_0, axis=[0])

        if P is not None:
            P = _op.squeeze(P, axis=[0])
            p_i, p_o, p_f = _op.split(P, 3)
        H_t = h_0
        C_t = c_0
        h_list = []

        if 'activations' in attr:
            activations = attr['activations']
            if len(activations) != 3:
                raise NotImplementedError("LSTM assumes 3 activation functions are provided")
            alpha_loc = 0
            alphas = attr.get('activation_alpha', [])
            if isinstance(alphas, float):
                alphas = [alphas]
            beta_loc = 0
            betas = attr.get('activation_beta', [])
            if isinstance(betas, float):
                betas = [betas]
            acts = []
            for i in range(3):
                alpha = None
                beta = None
                activation = activations[i]
                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
                    alpha = alphas[alpha_loc]
                    alpha_loc += 1
                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
                    beta = betas[beta_loc]
                    beta_loc += 1
                acts.append(cls._activation_helper(activation, alpha, beta))
            f_act, g_act, h_act = acts
        else:
            f_act = _op.sigmoid
            g_act = _op.tanh
            h_act = _op.tanh

        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
        for step in X_steps:
            step = _op.squeeze(step, axis=[0])
            gates = _op.nn.dense(step, W) + _op.nn.dense(H_t, R)
            if B is not None:
                WB, RB = _op.split(B, 2)
                gates += WB + RB
            i, o, f, c = _op.split(gates, 4, axis=-1)
            if P is not None:
                i = f_act(i + p_i * C_t)
                f = f_act(f + p_f * C_t)
            else:
                i = f_act(i)
                f = f_act(f)
            c = g_act(c)
            C = f * C_t + i * c
            if P is not None:
                o = f_act(o + p_o * C)
            else:
                o = f_act(o)
            H = o * h_act(C)
            H_t = H
            C_t = C
            h_list.append(_op.expand_dims(H, axis=0))
        # Concatenate outputs and add back in direction axis.
        concatenated = _op.concatenate(h_list, 0)
        output = _op.expand_dims(concatenated, axis=1)
        H_t = _op.expand_dims(H_t, axis=0)
        C_t = _op.expand_dims(C_t, axis=0)

        return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)

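The converter unrolls the sequence one step at a time with Relay ops. As a cross-check, here is a minimal NumPy sketch of a single step with the same i, o, f, c gate ordering, optional combined bias, and optional peepholes; it is an illustrative restatement of the ONNX LSTM equations under the default sigmoid/tanh/tanh activations, not code from this commit, and the shapes noted in the comments are assumptions.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step_reference(x_t, h_prev, c_prev, W, R, B=None, P=None):
    # x_t: (batch, input_size), h_prev/c_prev: (batch, hidden),
    # W: (4*hidden, input_size), R: (4*hidden, hidden),
    # B: (8*hidden,) concatenation of input and recurrent biases,
    # P: (3*hidden,) peephole weights in i, o, f order.
    gates = x_t @ W.T + h_prev @ R.T
    if B is not None:
        wb, rb = np.split(B, 2)
        gates = gates + wb + rb
    i, o, f, c = np.split(gates, 4, axis=-1)
    if P is not None:
        p_i, p_o, p_f = np.split(P, 3)
        i = sigmoid(i + p_i * c_prev)
        f = sigmoid(f + p_f * c_prev)
    else:
        i, f = sigmoid(i), sigmoid(f)
    c = np.tanh(c)
    c_new = f * c_prev + i * c
    o = sigmoid(o + p_o * c_new) if P is not None else sigmoid(o)
    h_new = o * np.tanh(c_new)
    return h_new, c_new

With the default activations this mirrors the f_act/g_act/h_act path taken when no 'activations' attribute is present.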

# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1203,7 +1408,7 @@ def _get_convert_map(opset):
    return {
        # defs/experimental
        'Identity': Renamer('copy'),
-        # 'Affine'
+        'Affine': Affine.get_converter(opset),
        'ThresholdedRelu': ThresholdedRelu.get_converter(opset),
        'ScaledTanh': ScaledTanh.get_converter(opset),
        'ParametricSoftplus': ParametricSoftPlus.get_converter(opset),
@@ -1281,6 +1486,8 @@ def _get_convert_map(opset):
        'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']),
        'Flatten': Flatten.get_converter(opset),
        'LRN': LRN.get_converter(opset),
        # Recurrent Layers
        'LSTM': LSTM.get_converter(opset),

        # defs/reduction
        'ReduceMax': ReduceMax.get_converter(opset),
@@ -1414,7 +1621,11 @@ def from_onnx(self, graph, opset):
        for node in graph.node:
            op_name = node.op_type
            attr = self._parse_attr(node.attribute)
-            inputs = [self._nodes[self._renames.get(i, i)] for i in node.input]
+            # Create and populate onnx input object.
+            inputs = onnx_input()
+            for i in node.input:
+                if i != '':
+                    inputs[i] = self._nodes[self._renames.get(i, i)]
            if op_name == "Constant":
                t_proto = self._parse_attr(node.attribute)["value"]
                self._num_param += 1
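ONNX marks omitted optional inputs with empty-string names, so the loop above simply skips them and a later name lookup on the onnx_input object returns None. A tiny illustration of that flow (hypothetical names standing in for Relay expressions):

# Hypothetical node.input list for an LSTM node with B and sequence_lens omitted.
node_input = ['X', 'W', 'R', '', '', 'initial_h']
populated = onnx_input()
for name in node_input:
    if name != '':
        populated[name] = 'expr_for_' + name
assert populated[0] == 'expr_for_X'        # positional access still works
assert populated['B'] is None              # omitted optional input reads as None
assert populated['initial_h'] == 'expr_for_initial_h'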