[Relay][Frontend][Onnx] GRU Layer Support (#6020)
* GRU debugging and testing added to onnx frontend.

* All tests working and code formatted.

* Fix lint issues.

* Add a test case and change RNN argument parsing.

* Small refactor.
jwfromm authored Jul 12, 2020
1 parent d6ceba0 commit 9f7745e
Showing 2 changed files with 366 additions and 85 deletions.
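
For orientation, here is a minimal sketch of exercising the new GRU conversion. The graph construction is illustrative rather than taken from the commit's tests, and it assumes only the required GRU inputs X, W, and R:

```python
# Illustrative only: a one-node ONNX GRU graph imported through the Relay frontend.
from onnx import TensorProto, helper
from tvm import relay

seq_len, batch, input_size, hidden_size = 4, 1, 8, 16

gru_node = helper.make_node(
    "GRU", inputs=["X", "W", "R"], outputs=["Y", "Y_h"], hidden_size=hidden_size)

shapes = {
    "X": (seq_len, batch, input_size),
    "W": (1, 3 * hidden_size, input_size),
    "R": (1, 3 * hidden_size, hidden_size),
}
graph = helper.make_graph(
    [gru_node],
    "gru_example",
    inputs=[helper.make_tensor_value_info(name, TensorProto.FLOAT, list(shape))
            for name, shape in shapes.items()],
    outputs=[
        helper.make_tensor_value_info(
            "Y", TensorProto.FLOAT, [seq_len, 1, batch, hidden_size]),
        helper.make_tensor_value_info(
            "Y_h", TensorProto.FLOAT, [1, batch, hidden_size]),
    ])
model = helper.make_model(graph, producer_name="gru_example")

# The importer maps the GRU node onto the loop of dense/split/elementwise ops added below.
mod, params = relay.frontend.from_onnx(model, shape=shapes)
```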
140 changes: 131 additions & 9 deletions python/tvm/relay/frontend/onnx.py
@@ -60,6 +60,8 @@ def __init__(self):

    def __getitem__(self, item):
        if isinstance(item, int):
            if item > (len(self.input_keys) - 1):
                return None
            return self.input_dict[self.input_keys[item]]
        if isinstance(item, str):
            if item not in self.input_keys:
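
The two added lines make integer indexing of the ONNX input container tolerant of out-of-range positions, which the RNN converters below rely on to read optional inputs positionally. A rough sketch of the intended behavior, assuming the container is this file's onnx_input helper and only the required inputs are populated:

```python
# Hypothetical illustration, not part of the commit.
from tvm.relay.frontend.onnx import onnx_input

x, w, r = "X_expr", "W_expr", "R_expr"   # stand-ins for Relay expressions
inputs = onnx_input()
inputs["X"], inputs["W"], inputs["R"] = x, w, r
assert inputs[2] is r        # in-range integer lookups behave as before
assert inputs[3] is None     # optional input "B" was omitted, so position 3 is None
```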
@@ -1493,8 +1495,8 @@ def expand_shape(in_shape, shape):
        return _op.broadcast_to(inputs[0], shape=tuple(shape))


-class LSTM(OnnxOpConverter):
-    """ Operator converter for LSTM.
+class RNN(OnnxOpConverter):
+    """ Operator converter for RNNs such as LSTM and GRU.
    """

    @classmethod
@@ -1528,18 +1530,23 @@ def _activation_needs_beta(cls, activation):
        ]
        return activation.decode("utf-8") in needs_beta


class LSTM(RNN):
    """Operator converter for LSTM
    """

    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        # Unpack inputs, note that if optional and not provided then value will be None.
        X = inputs[0]
        W = inputs[1]
        R = inputs[2]
-        B = inputs['B']
+        B = inputs[3]
        # Sequence length currently unused as it can be inferred from shapes.
        #sequence_lens = inputs['sequence_lens']
-        h_0 = inputs['initial_h']
-        c_0 = inputs['initial_c']
-        P = inputs['P']
+        h_0 = inputs[5]
+        c_0 = inputs[6]
+        P = inputs[7]

        num_directions = infer_shape(W)[0]
        W_dtype = infer_type(W).type_annotation.dtype
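
For context, the switch from string keys to integer indices follows the positional input order defined by the ONNX spec for LSTM, so trailing optional inputs that a model omits now simply come back as None:

```python
# Positional inputs of the ONNX LSTM operator (per the ONNX operator spec):
ONNX_LSTM_INPUTS = (
    "X", "W", "R", "B", "sequence_lens", "initial_h", "initial_c", "P")
# hence B, initial_h, initial_c and P are read at indices 3, 5, 6 and 7 above.
```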
@@ -1577,7 +1584,8 @@ def _impl_v7(cls, inputs, attr, params):
        if 'activations' in attr:
            activations = attr['activations']
            if len(activations) != 3:
-                raise NotImplementedError("LSTM assumes 3 activation functions are provided")
+                raise NotImplementedError(
+                    "LSTM assumes 3 activation functions are provided")
            alpha_loc = 0
            alphas = attr.get('activation_alpha', [])
            if isinstance(alphas, float):
@@ -1591,10 +1599,12 @@ def _impl_v7(cls, inputs, attr, params):
                alpha = None
                beta = None
                activation = activations[i]
-                if cls._activation_needs_alpha(activation) and len(alphas) > alpha_loc:
+                if cls._activation_needs_alpha(
+                        activation) and len(alphas) > alpha_loc:
                    alpha = alphas[alpha_loc]
                    alpha_loc += 1
-                if cls._activation_needs_beta(activation) and len(betas) > beta_loc:
+                if cls._activation_needs_beta(
+                        activation) and len(betas) > beta_loc:
                    beta = betas[beta_loc]
                    beta_loc += 1
                acts.append(cls._activation_helper(activation, alpha, beta))
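
To make the alpha/beta bookkeeping concrete, the ONNX attributes this loop consumes might look like the following for a forward LSTM (an illustrative assumption, not a case from the commit's tests):

```python
# ONNX passes activation names as byte strings; alpha values are consumed in
# order by the activations that need them (here only LeakyRelu takes an alpha).
attr = {
    'activations': [b'LeakyRelu', b'Tanh', b'Tanh'],  # f, g, h activations
    'activation_alpha': [0.1],
}
```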
@@ -1638,6 +1648,117 @@ def _impl_v7(cls, inputs, attr, params):
        return _expr.TupleWrapper(_expr.Tuple((output, H_t, C_t)), 3)


class GRU(RNN):
    """Operator converter for GRU
    """

    @classmethod
    def _impl_v7(cls, inputs, attr, params):
        # Unpack inputs, note that if optional and not provided then value will be None.
        X = inputs[0]
        W = inputs[1]
        R = inputs[2]
        B = inputs[3]
        # Sequence length currently unused as it can be inferred from shapes.
        #sequence_lens = inputs['sequence_lens']
        h_0 = inputs[5]
        linear_before_reset = attr.get('linear_before_reset', 0)

        num_directions = infer_shape(W)[0]
        W_dtype = infer_type(W).type_annotation.dtype

        if num_directions != 1:
            raise NotImplementedError("Bidirectional GRUs not yet supported.")
        # Remove num_directions axis from weights.
        W = _op.squeeze(W, axis=[0])
        R = _op.squeeze(R, axis=[0])
        if B is not None:
            B = _op.squeeze(B, axis=[0])

        X_shape = infer_shape(X)
        hidden_size = infer_shape(R)[-1]
        batch_size = X_shape[1]

        # Initialize state if not provided.
        # Otherwise remove bidirectional axis.
        if h_0 is None:
            h_0 = _op.zeros((batch_size, hidden_size), W_dtype)
        else:
            h_0 = _op.squeeze(h_0, axis=[0])

        H_t = h_0
        h_list = []

        if 'activations' in attr:
            activations = attr['activations']
            if len(activations) != 2:
                raise NotImplementedError(
                    "GRU assumes 2 activation functions are provided")
            alpha_loc = 0
            alphas = attr.get('activation_alpha', [])
            if isinstance(alphas, float):
                alphas = [alphas]
            beta_loc = 0
            betas = attr.get('activation_beta', [])
            if isinstance(betas, float):
                betas = [betas]
            acts = []
            for i in range(2):
                alpha = None
                beta = None
                activation = activations[i]
                if cls._activation_needs_alpha(
                        activation) and len(alphas) > alpha_loc:
                    alpha = alphas[alpha_loc]
                    alpha_loc += 1
                if cls._activation_needs_beta(
                        activation) and len(betas) > beta_loc:
                    beta = betas[beta_loc]
                    beta_loc += 1
                acts.append(cls._activation_helper(activation, alpha, beta))
            f_act, g_act = acts
        else:
            f_act = _op.sigmoid
            g_act = _op.tanh

        X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
        for step in X_steps:
            step = _op.squeeze(step, axis=[0])
            current = _op.nn.dense(step, W)
            cz, cr, ch = _op.split(current, 3, axis=1)
            rz, rr, rh = _op.split(R, 3, axis=0)
            z = cz + _op.nn.dense(H_t, rz)
            r = cr + _op.nn.dense(H_t, rr)
            if B is not None:
                WB, RB = _op.split(B, 2)
                wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
                rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
                z += wbz + rbz
                r += wbr + rbr
                if linear_before_reset:
                    h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
                else:
                    h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
            else:
                if linear_before_reset:
                    h = ch + (r * (_op.nn.dense(H_t, rh)))
                else:
                    h = ch + _op.nn.dense((r * H_t), rh)

            z = f_act(z)
            r = f_act(r)
            h = g_act(h)

            H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t)
            h_list.append(_op.expand_dims(H_t, axis=0))
        # Concatenate outputs and add back in direction axis.
        concatenated = _op.concatenate(h_list, 0)
        output = _op.expand_dims(concatenated, axis=1)
        H_t = _op.expand_dims(H_t, axis=0)

        return _expr.TupleWrapper(_expr.Tuple((output, H_t)), 2)
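
For reference, here is a NumPy sketch of the single GRU step defined by the ONNX spec with the default activations and linear_before_reset = 0; it restates the spec rather than quoting code from this commit:

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def onnx_gru_step(x_t, h_prev, W, R, Wb, Rb):
    """One GRU step. W: (3*hidden, input), R: (3*hidden, hidden),
    Wb, Rb: (3*hidden,); gates are stacked in [z, r, h] order as in ONNX."""
    wz, wr, wh = np.split(W, 3, axis=0)
    rz, rr, rh = np.split(R, 3, axis=0)
    wbz, wbr, wbh = np.split(Wb, 3)
    rbz, rbr, rbh = np.split(Rb, 3)
    z = sigmoid(x_t @ wz.T + h_prev @ rz.T + wbz + rbz)
    r = sigmoid(x_t @ wr.T + h_prev @ rr.T + wbr + rbr)
    h_tilde = np.tanh(x_t @ wh.T + (r * h_prev) @ rh.T + wbh + rbh)
    return (1.0 - z) * h_tilde + z * h_prev
```

The Relay loop above follows the same stacked [z, r, h] layout, using _op.split to slice the weights and _op.nn.dense for the matrix products.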


class Resize(OnnxOpConverter):
"""Operator converter for Resize
"""
@@ -1859,6 +1980,7 @@ def _get_convert_map(opset):
        'LRN': LRN.get_converter(opset),
        # Recurrent Layers
        'LSTM': LSTM.get_converter(opset),
        'GRU': GRU.get_converter(opset),

        # defs/vision
        'MaxRoiPool': MaxRoiPool.get_converter(opset),