Skip to content

Commit

Permalink
Constant input attr added to fully connected operation in TFLite fron…
Browse files Browse the repository at this point in the history
…tend (apache#6228)

* Constant input attr added to fully connected operation

The ability to handle a constant input attribute has been added to the fully connected operation.
Unit tests were amended accordingly.

* renamed wrap_input to const_input

* removed extra spaces
  • Loading branch information
d-smirnov authored and Trevor Morris committed Aug 26, 2020
1 parent d2b8ade commit 112adb5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
5 changes: 2 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1708,10 +1708,9 @@ def convert_fully_connected(self, op):
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) >= 2, "input tensors length should be >= 2"
assert len(input_tensors) in (2, 3), "input tensors length should be two or three"

input_tensor = input_tensors[0]
input_tensor_idx = input_tensor.tensor_idx
weight_tensor = input_tensors[1]

output_tensors = self.get_output_tensors(op)
Expand All @@ -1733,7 +1732,7 @@ def convert_fully_connected(self, op):
# Dense expected Weight shape: [out_dim, n_units]
# Dense output shape: [batch_size, out_dim]
target_shape = tuple((-1, weight_tensor_shape[1]))
in_expr = self.get_expr(input_tensor_idx)
in_expr = self.get_tensor_expr(input_tensor)
in_expr = _op.reshape(in_expr, target_shape)

#TODO: Change the output shape calculation based on keep_dim option
Expand Down
48 changes: 27 additions & 21 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2576,25 +2576,27 @@ def test_forward_sparse_to_dense():
# Fully Connected
# ---------------

def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
def _test_fully_connected(tensor_in_sizes, const_input, filter_in_sizes, bias_in_size=None):
""" One iteration of fully connected """

total_size_1 = 1
total_size_2 = 1
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
total_size_1 = np.prod(tensor_in_sizes)
total_size_2 = np.prod(filter_in_sizes)

assert int(total_size_1 / tensor_in_sizes[0]) == filter_in_sizes[0], \
"input size and filter size are mismatched"

# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = np.arange(1, total_size_1 + 1, dtype=np.float32)
filter_array = np.arange(1, total_size_2 + 1, dtype=np.float32)

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype='float32')
in_name="input"
in_data = constant_op.constant(data_array, shape=tensor_in_sizes, dtype=np.float32, name=in_name) \
if const_input \
else array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32, name=in_name)

in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype=np.float32)

# reshape N H W C into N H*W*C
in_data_reshape = array_ops.reshape(in_data, [tensor_in_sizes[0], -1])
Expand All @@ -2604,20 +2606,24 @@ def _test_fully_connected(tensor_in_sizes, filter_in_sizes, bias_in_size=None):
# if we have bias
if bias_in_size:
assert bias_in_size[0] == filter_in_sizes[1], "bias and filter size are mismatched"
bias_array = [f * 1.0 for f in range(1, bias_in_size[0] + 1)]
in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype='float32')
bias_array = np.arange(1, bias_in_size[0] + 1, dtype=np.float32)
in_bias = constant_op.constant(bias_array, shape=bias_in_size, dtype=np.float32)
out = nn_ops.bias_add(out, in_bias)

data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])
data_array = np.reshape(data_array, tensor_in_sizes).astype(np.float32)
compare_tflite_with_tvm(data_array,
[] if const_input else in_data.name,
[in_data],
[out])


def test_forward_fully_connected():
    """Fully Connected.

    Exercises the TFLite FULLY_CONNECTED conversion for both a placeholder
    input and a constant input (``const_input=True``), with and without a
    bias tensor, for batch sizes 1 and 5.
    """
    # NOTE: the stale pre-refactor calls (which omitted the const_input
    # argument) were removed; _test_fully_connected now requires it as its
    # second positional parameter.
    for const_input in [False, True]:
        _test_fully_connected([1, 1, 1, 150], const_input, [150, 100])
        _test_fully_connected([1, 1, 1, 150], const_input, [150, 100], [100])
        _test_fully_connected([5, 1, 1, 150], const_input, [150, 100])
        _test_fully_connected([5, 1, 1, 150], const_input, [150, 100], [100])


#######################################################################
Expand Down

0 comments on commit 112adb5

Please sign in to comment.