Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][TensorFlow] Add support for SquaredDifference #3930

Merged
merged 5 commits into from
Sep 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def _impl(inputs, attr, params):

# reshape result back to n-dimensional
if len(orig_shape_x) > 3:
final_shape = attr['_output_shapes'][0]
final_shape = list(orig_shape_x)
final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1]
ret = _op.reshape(ret, newshape=final_shape)

return ret
Expand Down Expand Up @@ -1227,6 +1229,12 @@ def _impl(inputs, attr, params):
extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
return _impl

def _squared_difference():
def _impl(inputs, attr, params):
difference = _op.subtract(inputs[0], inputs[1])
return _op.multiply(difference, difference)
return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1334,6 +1342,7 @@ def _impl(inputs, attr, params):
'SplitV' : _split(True),
'Sqrt' : AttrCvt('sqrt'),
'Square' : _square(),
'SquaredDifference' : _squared_difference(),
'Squeeze' : _squeeze(),
'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'),
Expand Down
11 changes: 11 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,6 +1852,16 @@ def test_forward_erf():
tf.math.erf(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')

def test_forward_squared_difference():
ishape = (1, 3, 10, 14)
inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1")
in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2")
out = tf.math.squared_difference(in1, in2)
compare_tf_with_tvm([inp_array_a, inp_array_b], [in1.name, in2.name], out.name)

def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
Expand Down Expand Up @@ -2253,6 +2263,7 @@ def test_forward_one_hot():
test_forward_bias_add()
test_forward_zeros_like()
test_forward_erf()
test_forward_squared_difference()

# Reductions
test_forward_argminmax()
Expand Down