Skip to content

Commit

Permalink
[Frontend][TFLite] Fix fully_connected converter when batch size is n…
Browse files Browse the repository at this point in the history
…ot 1 (#6038)

* Fix fully_connected when batched

* Remove unused variable
  • Loading branch information
Trevor Morris authored Jul 14, 2020
1 parent 712c82f commit 99c52f3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
10 changes: 1 addition & 9 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,7 +1706,6 @@ def convert_fully_connected(self, op):
output_tensor_type = output_tensor.tensor.Type()
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()

# Weight should have only 2 dimensions(TFLite convention)
Expand All @@ -1719,14 +1718,7 @@ def convert_fully_connected(self, op):
# Dense expected Input shape: [batch_size, n_units]
# Dense expected Weight shape: [out_dim, n_units]
# Dense output shape: [batch_size, out_dim]
# So it is evident that input shape: [batch_size = input_size / n_units, n_units]
input_size = 1
for _, shape in enumerate(input_tensor_shape):
input_size *= shape

# First get the batch size
batch_size = int(input_size / weight_tensor_shape[1])
target_shape = tuple((batch_size, weight_tensor_shape[1]))
target_shape = tuple((-1, weight_tensor_shape[1]))
in_expr = self.get_expr(input_tensor_idx)
in_expr = _op.reshape(in_expr, target_shape)

Expand Down
21 changes: 19 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,13 @@ def run_tflite_graph(tflite_model_buf, input_data):
input_data = convert_to_list(input_data)

interpreter = interpreter_wrapper.Interpreter(model_content=tflite_model_buf)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

for i in range(len(input_details)):
interpreter.resize_tensor_input(input_details[i]['index'], input_data[i].shape)
interpreter.allocate_tensors()

# set input
assert len(input_data) == len(input_details)
for i in range(len(input_details)):
Expand Down Expand Up @@ -2548,6 +2550,20 @@ def test_forward_inception_v4_net():
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

def test_forward_inception_v4_net_batched():
"""Test the Inception V4 TF Lite model."""
# InceptionV4
tflite_model_file = tf_testing.get_workload_official(
"https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz",
"inception_v4.tflite")
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
data = np.random.uniform(size=(4, 299, 299, 3)).astype('float32')
tflite_output = run_tflite_graph(tflite_model_buf, data)
tvm_output = run_tvm_graph(tflite_model_buf, data, 'input')
tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
rtol=1e-5, atol=1e-5)

def test_forward_qnn_inception_v1_net():
"""Test the Quantized TFLite Inception model."""
# InceptionV1
Expand Down Expand Up @@ -2914,6 +2930,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_mobilenet_v3()
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_inception_v4_net_batched()
test_forward_coco_ssd_mobilenet_v1()
test_forward_mediapipe_hand_landmark()

Expand Down

0 comments on commit 99c52f3

Please sign in to comment.