Skip to content

Commit

Permalink
[Relay][TensorFlow] Add support for SquaredDifference (apache#3930)
Browse files Browse the repository at this point in the history
* Add support for SquaredDifference and StopGradient; minor fix in BatchMatMul

* Remove stopgradient change

* Resolve PR comment

* Dummy change to retrigger CI

* dummy change to retrigger CI
  • Loading branch information
soiferj authored and wweic committed Sep 16, 2019
1 parent 4b60ceb commit b708b39
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ def _impl(inputs, attr, params):

# reshape result back to n-dimensional
if len(orig_shape_x) > 3:
final_shape = attr['_output_shapes'][0]
final_shape = list(orig_shape_x)
final_shape[-2] = orig_shape_x[-1] if adj_x else orig_shape_x[-2]
final_shape[-1] = orig_shape_y[-2] if adj_y else orig_shape_y[-1]
ret = _op.reshape(ret, newshape=final_shape)

return ret
Expand Down Expand Up @@ -1227,6 +1229,12 @@ def _impl(inputs, attr, params):
extras={'depth' : depth, 'dtype' : dtype})(new_inputs, attr)
return _impl

def _squared_difference():
def _impl(inputs, attr, params):
difference = _op.subtract(inputs[0], inputs[1])
return _op.multiply(difference, difference)
return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1334,6 +1342,7 @@ def _impl(inputs, attr, params):
'SplitV' : _split(True),
'Sqrt' : AttrCvt('sqrt'),
'Square' : _square(),
'SquaredDifference' : _squared_difference(),
'Squeeze' : _squeeze(),
'StridedSlice' : _stridedSlice(),
'Sub' : _elemwise('subtract'),
Expand Down
11 changes: 11 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,6 +1852,16 @@ def test_forward_erf():
tf.math.erf(in1)
compare_tf_with_tvm(inp_array, 'Placeholder:0', 'Erf:0')

def test_forward_squared_difference():
ishape = (1, 3, 10, 14)
inp_array_a = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
inp_array_b = np.random.uniform(-5, 5, size=ishape).astype(np.float32)
with tf.Graph().as_default():
in1 = tf.placeholder(shape=inp_array_a.shape, dtype=inp_array_a.dtype, name="in1")
in2 = tf.placeholder(shape=inp_array_b.shape, dtype=inp_array_b.dtype, name="in2")
out = tf.math.squared_difference(in1, in2)
compare_tf_with_tvm([inp_array_a, inp_array_b], [in1.name, in2.name], out.name)

def _test_forward_reverse_v2(in_shape, axis, dtype):
np_data = np.random.uniform(-10, 10, size=in_shape).astype(dtype)
tf.reset_default_graph()
Expand Down Expand Up @@ -2253,6 +2263,7 @@ def test_forward_one_hot():
test_forward_bias_add()
test_forward_zeros_like()
test_forward_erf()
test_forward_squared_difference()

# Reductions
test_forward_argminmax()
Expand Down

0 comments on commit b708b39

Please sign in to comment.