Skip to content

Commit

Permalink
[Relay][Frontend][TF] Fix transpose when axes is not a param (apache#…
Browse files Browse the repository at this point in the history
…4327)

* [Relay][Frontend][TF] Use _infer_value_simulated when axes is not a const to Transpose

* uncomment tests

* dummy change to retrigger ci
  • Loading branch information
soiferj authored and Xingyu Zhou committed Nov 15, 2019
1 parent fd1c187 commit 2a012a0
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,8 +1039,8 @@ def _impl(inputs, attr, params):
# otherwise its value is get from params
try:
axes = _get_list_param(params, inputs[1])
except (IndexError, KeyError):
axes = None
except (IndexError, KeyError, AttributeError):
axes = _infer_value_simulated(inputs[1], params).asnumpy()
return _op.transpose(inputs[0], axes=axes)
return _impl

Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,22 @@ def _test_forward_transpose(ishape, axes=None):

compare_tf_with_tvm(data, 'transpose_data:0', 'transpose:0')

def _test_forward_tranapose_axes_input(ishape, axes):
data = np.random.uniform(size=ishape).astype(np.float32)
axes_np = np.array(axes).astype(np.int32)

with tf.Graph().as_default():
in1 = tf.placeholder(
shape=data.shape, dtype=data.dtype, name="transpose_data")

const1 = tf.constant(axes_np, dtype=tf.int32)

# make axes an input to tf.transpose, but not an input to the graph,
# so it can be extracted with infer_value_simulated
axes = tf.reverse(const1, axis=[-1])
tf.transpose(in1, axes)

compare_tf_with_tvm([data], ['transpose_data:0'], 'transpose:0')

def test_forward_transpose():
_test_forward_transpose((2, 3, 4), (1, 2, 0))
Expand All @@ -2122,6 +2138,8 @@ def test_forward_transpose():
_test_forward_transpose((2, 3, 4), (1, 2, 0))
_test_forward_transpose((2, 3, 4), (0, 1, 2))
_test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
_test_forward_tranapose_axes_input((2, 3, 4), (1, 2, 0))
_test_forward_tranapose_axes_input((2, 3, 4, 5), (3, 0, 1, 2))


def test_forward_ceil():
Expand Down

0 comments on commit 2a012a0

Please sign in to comment.