diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 6e032b1efda8..17908b229bd7 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -114,6 +114,7 @@ def __init__(self, model, subgraph, exp_tab): 'MUL': self.convert_mul, 'NEG': self.convert_neg, 'NOT_EQUAL': self.convert_not_equal, + 'ONE_HOT': self.convert_one_hot, 'PACK': self.convert_pack, 'PAD': self.convert_pad, 'PADV2': self.convert_pad, @@ -2886,6 +2887,56 @@ def convert_detection_postprocess(self, op): ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4) return ret + def convert_one_hot(self, op): + """Convert TFLite ONE_HOT""" + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.OneHotOptions import OneHotOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 4, "Input tensor's length should be 4" + + # Ensuring input isn't quantized + assert all(not i.qnn_params for i in input_tensors), \ + "Quantized input is not expected." + + # TFlite ONE_HOT requires both on_value + # and off_value, making dtype redundant. + indices = input_tensors[0] + depth = input_tensors[1] + on_value = input_tensors[2] + off_value = input_tensors[3] + + assert on_value.tensor.Type() == off_value.tensor.Type(), \ + "on_value and off_value should be the same type" + + # Getting relay expr + indices_expr = self.get_expr(indices.tensor_idx) + on_value_expr = self.get_expr(on_value.tensor_idx) + off_value_expr = self.get_expr(off_value.tensor_idx) + + # Getting depth value + depth = self.get_tensor_value(depth) + if isinstance(depth, np.ndarray): + depth = int(depth) + + # Getting Axis from Option (Attributes) + assert op.BuiltinOptionsType() == BuiltinOptions.OneHotOptions + op_options = op.BuiltinOptions() + one_hot_options = OneHotOptions() + one_hot_options.Init(op_options.Bytes, op_options.Pos) + axis = one_hot_options.Axis() + + # Setting dtype + dtype = self.get_tensor_type_str(on_value.tensor.Type()) + + out = _op.one_hot(indices_expr, on_value_expr, off_value_expr, depth, axis, dtype) + + return out + + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 603eb1169624..304ef60a7abf 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -2015,6 +2015,35 @@ def test_forward_padv2(): np.uint8(10)], quantized=True) +####################################################################### +# ONE_HOT +# ------- + +def _test_one_hot(indices, depth, on_value, off_value, axis = None): + """ One iteration of One_Hot """ + with tf.Graph().as_default(): + in_indices = tf.placeholder(dtype=indices.dtype, shape=indices.shape, name="indices") + in_depth = ops.convert_to_tensor(depth, dtype=depth.dtype) + in_on_value = tf.placeholder(dtype=on_value.dtype, shape=on_value.shape, name="on_value") + in_off_value = tf.placeholder(dtype=off_value.dtype, shape=off_value.shape, name="off_value") + if axis is not None: + out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value, axis=axis) + else: + out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value) + compare_tflite_with_tvm( + [indices, on_value, off_value], + ["indices", "on_value", "off_value"], + [in_indices, in_on_value, in_off_value], + [out]) + +def test_forward_one_hot(): + """ One_Hot """ + _test_one_hot(np.int32(2), np.int32(8), np.int32(1), np.int32(0)) + _test_one_hot(np.int32(4), np.int32(8), np.float32(1), np.float32(0)) + _test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1)) + _test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1), axis=0) + + ####################################################################### # Pack # ----