Skip to content

Commit

Permalink
[TFLite] Implemented ONE_HOT Operator for TFLite (apache#6223)
Browse files Browse the repository at this point in the history
  • Loading branch information
jainris authored and wjliu1998 committed Aug 13, 2020
1 parent 22fdedd commit 7afc712
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
51 changes: 51 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, model, subgraph, exp_tab):
'MUL': self.convert_mul,
'NEG': self.convert_neg,
'NOT_EQUAL': self.convert_not_equal,
'ONE_HOT': self.convert_one_hot,
'PACK': self.convert_pack,
'PAD': self.convert_pad,
'PADV2': self.convert_pad,
Expand Down Expand Up @@ -2903,6 +2904,56 @@ def convert_detection_postprocess(self, op):
ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
return ret

def convert_one_hot(self, op):
"""Convert TFLite ONE_HOT"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.OneHotOptions import OneHotOptions
except ImportError:
raise ImportError("The tflite package must be installed")

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 4, "Input tensor's length should be 4"

# Ensuring input isn't quantized
assert all(not i.qnn_params for i in input_tensors), \
"Quantized input is not expected."

# TFlite ONE_HOT requires both on_value
# and off_value, making dtype redundant.
indices = input_tensors[0]
depth = input_tensors[1]
on_value = input_tensors[2]
off_value = input_tensors[3]

assert on_value.tensor.Type() == off_value.tensor.Type(), \
"on_value and off_value should be the same type"

# Getting relay expr
indices_expr = self.get_expr(indices.tensor_idx)
on_value_expr = self.get_expr(on_value.tensor_idx)
off_value_expr = self.get_expr(off_value.tensor_idx)

# Getting depth value
depth = self.get_tensor_value(depth)
if isinstance(depth, np.ndarray):
depth = int(depth)

# Getting Axis from Option (Attributes)
assert op.BuiltinOptionsType() == BuiltinOptions.OneHotOptions
op_options = op.BuiltinOptions()
one_hot_options = OneHotOptions()
one_hot_options.Init(op_options.Bytes, op_options.Pos)
axis = one_hot_options.Axis()

# Setting dtype
dtype = self.get_tensor_type_str(on_value.tensor.Type())

out = _op.one_hot(indices_expr, on_value_expr, off_value_expr, depth, axis, dtype)

return out


def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))

Expand Down
29 changes: 29 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2030,6 +2030,35 @@ def test_forward_padv2():
np.uint8(10)], quantized=True)


#######################################################################
# ONE_HOT
# -------

def _test_one_hot(indices, depth, on_value, off_value, axis = None):
""" One iteration of One_Hot """
with tf.Graph().as_default():
in_indices = tf.placeholder(dtype=indices.dtype, shape=indices.shape, name="indices")
in_depth = ops.convert_to_tensor(depth, dtype=depth.dtype)
in_on_value = tf.placeholder(dtype=on_value.dtype, shape=on_value.shape, name="on_value")
in_off_value = tf.placeholder(dtype=off_value.dtype, shape=off_value.shape, name="off_value")
if axis is not None:
out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value, axis=axis)
else:
out = array_ops.one_hot(in_indices, in_depth, in_on_value, in_off_value)
compare_tflite_with_tvm(
[indices, on_value, off_value],
["indices", "on_value", "off_value"],
[in_indices, in_on_value, in_off_value],
[out])

def test_forward_one_hot():
""" One_Hot """
_test_one_hot(np.int32(2), np.int32(8), np.int32(1), np.int32(0))
_test_one_hot(np.int32(4), np.int32(8), np.float32(1), np.float32(0))
_test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1))
_test_one_hot(np.array([1, 2, 3], dtype=np.int32), np.int32(8), np.int32(3), np.int32(-1), axis=0)


#######################################################################
# Pack
# ----
Expand Down

0 comments on commit 7afc712

Please sign in to comment.