diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6e4a62dae359..59ba9f45c824 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -107,6 +107,7 @@ def __init__(self, model, subgraph, exp_tab): 'LOGICAL_NOT': self.convert_logical_not, 'LOGICAL_OR': self.convert_logical_or, 'LOGISTIC': self.convert_logistic, + 'MATRIX_DIAG': self.convert_matrix_diag, 'MATRIX_SET_DIAG': self.convert_matrix_set_diag, 'MAX_POOL_2D': self.convert_max_pool2d, 'MAXIMUM': self.convert_maximum, @@ -3020,6 +3021,30 @@ def convert_matrix_set_diag(self, op): out = _op.matrix_set_diag(input_expr, diagonal_expr) return out + def convert_matrix_diag(self, op): + """Convert TFLite MATRIX_DIAG""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensor's length should be 1" + + diagonal = input_tensors[0] + + if diagonal.qnn_params: + # Check that diagonal and output tensor have same qnn params. + output_tensors = self.get_output_tensors(op) + assert self.has_same_qnn_params(diagonal, output_tensors[0]), \ + "TFLite MATRIX_DIAG requires diagonal and output tensors' \ + scale and zero points to be equal" + + shape = diagonal.tensor.ShapeAsNumpy() + shape = np.append(shape, shape[-1]) + dtype = self.get_tensor_type_str(diagonal.tensor.Type()) + + input_expr = _op.zeros(tuple(shape), dtype) + diagonal_expr = self.get_tensor_expr(diagonal) + + out = _op.matrix_set_diag(input_expr, diagonal_expr) + return out + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 3577de3856f3..89296a63e5c0 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2760,6 +2760,33 @@ def test_forward_matrix_set_diag(): _test_matrix_set_diag((4, 4, 2), np.uint8, quantized=True) +####################################################################### +# MATRIX_DIAG +# ----------- + +def _test_matrix_diag(diagonal_shape, dtype): + """ One iteration of MATRIX_DIAG """ + with tf.Graph().as_default(): + diagonal = np.random.uniform(0, 100, diagonal_shape).astype(dtype) + in_diagonal = tf.placeholder(dtype=diagonal.dtype, shape=diagonal.shape, name="diagonal") + + out = array_ops.matrix_diag(in_diagonal) + + compare_tflite_with_tvm( + [diagonal], + ["diagonal"], + [in_diagonal], + [out], + experimental_new_converter=True) + +def test_forward_matrix_diag(): + """ MATRIX_DIAG """ + for dtype in [np.float32, np.int32]: + _test_matrix_diag((4), dtype) + _test_matrix_diag((5, 4, 3), dtype) + _test_matrix_diag((2, 3), dtype) + + ####################################################################### # Custom Operators # ---------------- @@ -3240,6 +3267,7 @@ def test_forward_mediapipe_hand_landmark(): test_forward_expand_dims() test_forward_reverse_v2() test_forward_matrix_set_diag() + test_forward_matrix_diag() # NN test_forward_convolution()