From 9e3644d5fff8a6f6ff2999305ec00480d896d856 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Fri, 15 Nov 2019 11:43:38 +0800 Subject: [PATCH] Solve custom model of prelu (#4326) --- python/tvm/relay/frontend/tflite.py | 3 +-- tests/python/frontend/tflite/test_forward.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 30209ff448f55..415f04eff52c1 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -1340,14 +1340,13 @@ def convert_prelu(self, op): alpha_tensor = input_tensors[1] alpha_tensor_type = alpha_tensor.tensor.Type() alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type) - alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor), + alpha_expr = self.exp_tab.new_const(self.get_tensor_value(alpha_tensor).flatten(), dtype=alpha_tensor_type_str) in_expr = self.get_expr(input_tensor.tensor_idx) out = _op.nn.prelu(in_expr, alpha_expr, axis=3) return out - def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 83a0730d74f87..8d1902694ee4b 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -934,18 +934,18 @@ def test_forward_relu(): """ ReLU """ _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6))) -def _test_prelu(data): +def _test_prelu(data, alpha): """ One iteration of PReLU """ with tf.Graph().as_default(): in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) - alpha = np.full((data.shape[-1],), 0.2, dtype=data.dtype) # This specific pattern will be replaced into PRelu by tflite out = nn_ops.relu(in_data) + (-alpha * nn_ops.relu(-in_data)) compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) def test_forward_prelu(): """ PReLU """ - _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32")) + _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((3,), 0.2, dtype="float32")) + _test_prelu(np.random.uniform(-5, 5, size=(1, 32, 32, 3)).astype("float32"), np.full((1, 1, 3), 0.2, dtype="float32")) ####################################################################### # Fully Connected