Skip to content

Commit

Permalink
[Relay][Frontend][TFLite] Add parser support for squared difference
Browse files Browse the repository at this point in the history
  • Loading branch information
wyc-ruiker committed Jan 8, 2020
1 parent bc0274d commit b0934bc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def __init__(self, model, subgraph, exp_tab):
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv,
'SQUARED_DIFFERENCE': self.convert_squared_difference,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -628,6 +629,15 @@ def convert_greater(self, op):
'TFlite quantized greater operator is not supported yet.')
return self._convert_elemwise(_op.greater, op)

def convert_squared_difference(self, op):
# Check if the input tensor is quantized, call QNN op
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized greater operator is not supported yet.')
difference = self._convert_elemwise(_op.subtract, op)
out = _op.multiply(difference, difference)
return out

def convert_zeros_like(self, op):
"""Convert TFLite ZEROS LIKE"""
try:
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,14 @@ def _test_greater(data):
""" One iteration of greater """
return _test_elemwise(math_ops.greater, data)

#######################################################################
# Squared_difference
# -------

def _test_squared_difference(data):
""" One iteration of squared difference """
return _test_elemwise(math_ops.squared_difference, data)

def _test_forward_elemwise(testop):
""" Elewise"""
testop([np.arange(6.0, dtype=np.float32).reshape((2, 1, 1, 3)),
Expand Down Expand Up @@ -765,6 +773,7 @@ def test_all_elemwise():
_test_forward_elemwise(_test_maximum)
_test_forward_elemwise(_test_minimum)
_test_forward_elemwise(_test_greater)
_test_forward_elemwise(_test_squared_difference)

#######################################################################
# Zeros like
Expand Down

0 comments on commit b0934bc

Please sign in to comment.