Skip to content

Commit

Permalink
[Relay][Frontend][TFLite] Add parser support for squared difference (a…
Browse files Browse the repository at this point in the history
…pache#4652)

* [Relay][Frontend][TFLite] Add parser support for squared difference

* fix some error

* fix exp_type

* add comment
  • Loading branch information
wyc-ruiker authored and alexwong committed Feb 28, 2020
1 parent c96e256 commit b2e35ab
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(self, model, subgraph, exp_tab):
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -735,6 +736,17 @@ def convert_greater(self, op):
'TFlite quantized greater operator is not supported yet.')
return self._convert_elemwise(_op.greater, op)

def convert_squared_difference(self, op):
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized squared difference operator is not supported yet.')
difference = self._convert_elemwise(_op.subtract, op)
# _convert_elemwise has guaranteed only have one output tensor
exp_type = self.get_tensor_type_str(self.get_output_tensors(op)[0].tensor.Type())
out = _op.power(difference, relay.const(2, exp_type))
return out

def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE"""
try:
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,14 @@ def _test_greater(data):
""" One iteration of greater """
return _test_elemwise(math_ops.greater, data)

#######################################################################
# Squared_difference
# ------------------

def _test_squared_difference(data):
""" One iteration of squared difference """
return _test_elemwise(math_ops.squared_difference, data)

def _test_forward_elemwise(testop):
""" Elewise"""
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
Expand Down Expand Up @@ -906,6 +914,7 @@ def test_all_elemwise():
_test_forward_elemwise(_test_maximum)
_test_forward_elemwise(_test_minimum)
_test_forward_elemwise(_test_greater)
_test_forward_elemwise(_test_squared_difference)

#######################################################################
# Zeros like
Expand Down

0 comments on commit b2e35ab

Please sign in to comment.