Skip to content

Commit

Permalink
[Frontend][Relay] Keras softmax and prelu fix (apache#6278) (apache#6278
Browse files Browse the repository at this point in the history
)

* prelu and softmax with NHWC layout consideration

* fix lint

* fix lint

Co-authored-by: Dongming Yang <[email protected]>
  • Loading branch information
2 people authored and trevor-m committed Sep 3, 2020
1 parent 2010b8a commit a12d23e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
21 changes: 13 additions & 8 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _convert_recurrent_activation(inexpr, keras_layer):
return _convert_activation(inexpr, act_type, None)


def _convert_activation(inexpr, keras_layer, _):
def _convert_activation(inexpr, keras_layer, etab):
if isinstance(keras_layer, str):
act_type = keras_layer
else:
Expand All @@ -80,7 +80,8 @@ def _convert_activation(inexpr, keras_layer, _):
beta = _expr.const(beta, dtype='float32')
return _op.add(_op.multiply(inexpr, alpha), beta)
if act_type == 'softmax':
return _op.nn.softmax(inexpr, axis=1)
axis = 1 if etab.data_layout == 'NCHW' else -1
return _op.nn.softmax(inexpr, axis)
if act_type == 'sigmoid':
return _op.sigmoid(inexpr)
if act_type == 'tanh':
Expand Down Expand Up @@ -123,10 +124,11 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
if isinstance(axis, list):
raise tvm.error.OpAttributeUnImplemented(
'Softmax with axes {} is not supported.'.format(axis))
if axis == -1:
axis = 1
else:
axis = axis + 1 if axis < dims - 1 else 1
if etab.data_layout == 'NCHW':
if axis == -1:
axis = 1
else:
axis = axis + 1 if axis < dims - 1 else 1
return _op.nn.softmax(inexpr, axis=axis)
if act_type == 'ReLU':
threshold = _expr.const(keras_layer.threshold, dtype='float32')
Expand All @@ -149,8 +151,11 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
_check_data_format(keras_layer)
size = len(keras_layer.alpha.shape)
alpha = etab.new_const(keras_layer.get_weights()[0] \
.transpose(np.roll(range(size), 1)))
if etab.data_layout == 'NCHW':
alpha = etab.new_const(keras_layer.get_weights()[0]
.transpose(np.roll(range(size), 1)))
else:
alpha = etab.new_const(keras_layer.get_weights()[0])
return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
if act_type == 'ThresholdedReLU':
theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
Expand Down
1 change: 1 addition & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def test_forward_activations(self, keras):
x = act_func(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model)
verify_keras_frontend(keras_model, need_transpose=False, layout='NHWC')


def test_forward_dense(self, keras):
Expand Down

0 comments on commit a12d23e

Please sign in to comment.