Skip to content

Commit

Permalink
[TFLite] Implemented MATRIX_DIAG Operator for TFLite. (apache#6397)
Browse files Browse the repository at this point in the history
* Added implementation for MATRIX_DIAG Operator.
* Added tests for MATRIX_DIAG Operator.
  • Loading branch information
jainris authored and kevinthesun committed Sep 18, 2020
1 parent e1468f2 commit a94aca4
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
25 changes: 25 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self, model, subgraph, exp_tab):
'LOGICAL_NOT': self.convert_logical_not,
'LOGICAL_OR': self.convert_logical_or,
'LOGISTIC': self.convert_logistic,
'MATRIX_DIAG': self.convert_matrix_diag,
'MATRIX_SET_DIAG': self.convert_matrix_set_diag,
'MAX_POOL_2D': self.convert_max_pool2d,
'MAXIMUM': self.convert_maximum,
Expand Down Expand Up @@ -3020,6 +3021,30 @@ def convert_matrix_set_diag(self, op):
out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out

def convert_matrix_diag(self, op):
"""Convert TFLite MATRIX_DIAG"""
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensor's length should be 1"

diagonal = input_tensors[0]

if diagonal.qnn_params:
# Check that diagonal and output tensor have same qnn params.
output_tensors = self.get_output_tensors(op)
assert self.has_same_qnn_params(diagonal, output_tensors[0]), \
"TFLite MATRIX_DIAG requires diagonal and output tensors' \
scale and zero points to be equal"

shape = diagonal.tensor.ShapeAsNumpy()
shape = np.append(shape, shape[-1])
dtype = self.get_tensor_type_str(diagonal.tensor.Type())

input_expr = _op.zeros(tuple(shape), dtype)
diagonal_expr = self.get_tensor_expr(diagonal)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out


def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
Expand Down
28 changes: 28 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2760,6 +2760,33 @@ def test_forward_matrix_set_diag():
_test_matrix_set_diag((4, 4, 2), np.uint8, quantized=True)


#######################################################################
# MATRIX_DIAG
# -----------

def _test_matrix_diag(diagonal_shape, dtype):
""" One iteration of MATRIX_DIAG """
with tf.Graph().as_default():
diagonal = np.random.uniform(0, 100, diagonal_shape).astype(dtype)
in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal")

out = array_ops.matrix_diag(in_diagonal)

compare_tflite_with_tvm(
[diagonal],
["diagonal"],
[in_diagonal],
[out],
experimental_new_converter=True)

def test_forward_matrix_diag():
""" MATRIX_DIAG """
for dtype in [np.float32, np.int32]:
_test_matrix_diag((4), dtype)
_test_matrix_diag((5, 4, 3), dtype)
_test_matrix_diag((2, 3), dtype)


#######################################################################
# Custom Operators
# ----------------
Expand Down Expand Up @@ -3240,6 +3267,7 @@ def test_forward_mediapipe_hand_landmark():
test_forward_expand_dims()
test_forward_reverse_v2()
test_forward_matrix_set_diag()
test_forward_matrix_diag()

# NN
test_forward_convolution()
Expand Down

0 comments on commit a94aca4

Please sign in to comment.