From 223a83108c649c0977758b806b83626da0f790ef Mon Sep 17 00:00:00 2001
From: Dongming Yang <50566938+domin1985@users.noreply.github.com>
Date: Tue, 25 Aug 2020 08:08:41 +0800
Subject: [PATCH] [Frontend][Relay] Keras softmax and prelu fix (#6278)
 (#6278)

* prelu and softmax with NHWC layout consideration

* fix lint

* fix lint

Co-authored-by: Dongming Yang
---
 python/tvm/relay/frontend/keras.py          | 21 +++++++++++++--------
 tests/python/frontend/keras/test_forward.py |  1 +
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 32de4718cc80b..b469ed0045a15 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -63,7 +63,7 @@ def _convert_recurrent_activation(inexpr, keras_layer):
     return _convert_activation(inexpr, act_type, None)
 
 
-def _convert_activation(inexpr, keras_layer, _):
+def _convert_activation(inexpr, keras_layer, etab):
     if isinstance(keras_layer, str):
         act_type = keras_layer
     else:
@@ -80,7 +80,8 @@ def _convert_activation(inexpr, keras_layer, _):
         beta = _expr.const(beta, dtype='float32')
         return _op.add(_op.multiply(inexpr, alpha), beta)
     if act_type == 'softmax':
-        return _op.nn.softmax(inexpr, axis=1)
+        axis = 1 if etab.data_layout == 'NCHW' else -1
+        return _op.nn.softmax(inexpr, axis)
     if act_type == 'sigmoid':
         return _op.sigmoid(inexpr)
     if act_type == 'tanh':
@@ -123,10 +124,11 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
         if isinstance(axis, list):
             raise tvm.error.OpAttributeUnImplemented(
                 'Softmax with axes {} is not supported.'.format(axis))
-        if axis == -1:
-            axis = 1
-        else:
-            axis = axis + 1 if axis < dims - 1 else 1
+        if etab.data_layout == 'NCHW':
+            if axis == -1:
+                axis = 1
+            else:
+                axis = axis + 1 if axis < dims - 1 else 1
         return _op.nn.softmax(inexpr, axis=axis)
     if act_type == 'ReLU':
         threshold = _expr.const(keras_layer.threshold, dtype='float32')
@@ -149,8 +151,11 @@ def _convert_advanced_activation(inexpr, keras_layer, etab):
         assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
         _check_data_format(keras_layer)
         size = len(keras_layer.alpha.shape)
-        alpha = etab.new_const(keras_layer.get_weights()[0] \
-                               .transpose(np.roll(range(size), 1)))
+        if etab.data_layout == 'NCHW':
+            alpha = etab.new_const(keras_layer.get_weights()[0]
+                                   .transpose(np.roll(range(size), 1)))
+        else:
+            alpha = etab.new_const(keras_layer.get_weights()[0])
         return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
     if act_type == 'ThresholdedReLU':
         theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index 8ddae9655d478..f9402554d53c7 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -182,6 +182,7 @@ def test_forward_activations(self, keras):
             x = act_func(data)
             keras_model = keras.models.Model(data, x)
             verify_keras_frontend(keras_model)
+            verify_keras_frontend(keras_model, need_transpose=False, layout='NHWC')
 
 
     def test_forward_dense(self, keras):
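
Note (not part of the patch): a minimal sketch of how the NHWC path added above could be exercised standalone, assuming a TVM build with this change applied and the layout argument of relay.frontend.from_keras that the updated test relies on; the 32x32x3 input shape and the use of classic keras (rather than tf.keras) are arbitrary example choices.

# Sketch only, assuming this patch is applied: a Keras softmax model converted
# with layout='NHWC' should now use axis=-1 instead of the NCHW default axis=1.
import keras
from tvm import relay

data = keras.layers.Input(shape=(32, 32, 3))      # channels-last input
x = keras.layers.Activation('softmax')(data)
keras_model = keras.models.Model(data, x)

# Build the shape dict from input_names, as verify_keras_frontend does.
shape_dict = {keras_model.input_names[0]: (1, 32, 32, 3)}
mod, params = relay.frontend.from_keras(keras_model, shape_dict, layout='NHWC')
print(mod)  # the printed Relay module should show nn.softmax(..., axis=-1)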