Skip to content

Commit

Permalink
[Relay][Frontend][Tensorflow] Fix type assignment for operator 'tf.ra…
Browse files Browse the repository at this point in the history
…nge' (#4294)
  • Loading branch information
cchung100m authored and zhiics committed Nov 12, 2019
1 parent 6252145 commit d184d2f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,14 +1075,15 @@ def _impl(inputs, attr, params):

return _impl


def _range():
def _impl(inputs, attr, params):
start = _get_param(params, inputs[0])[0]
limit = _get_param(params, inputs[1])[0] \
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
else params.pop('Rank').asnumpy()[0]
delta = _get_param(params, inputs[2])[0]
dtype = attr['dtype'].name if 'dtype' in attr else "int32"
dtype = attr['Tidx'].name if 'Tidx' in attr else str(start.dtype)
return AttrCvt(
op_name="arange",
ignores=['Tidx'],
Expand All @@ -1092,6 +1093,7 @@ def _impl(inputs, attr, params):
'dtype': dtype})([], attr)
return _impl


def _elu():
def _impl(inputs, attr, params):
dtype = attr['T'].name
Expand Down Expand Up @@ -1202,7 +1204,7 @@ def _impl(inputs, attr, params):
raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2')
if attr['sorted'] is False:
raise tvm.error.OpAttributeUnimplemented(
raise tvm.error.OpAttributeUnImplemented(
'Attribute sorted=False is not supported in operator TopKV2')
return AttrCvt(op_name='topk',
ignores=['sorted'],
Expand Down
5 changes: 5 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,11 @@ def test_forward_range():
tf.range(1, 18, 3, name="range")
compare_tf_with_tvm([], [], 'range:0')

"""test type assignment for operator Range"""
tf.reset_default_graph()
tf.range(1, 256 + 1, 1, dtype=tf.float32)
compare_tf_with_tvm([], [], 'range:0')

#######################################################################
# Pad
# ---
Expand Down

0 comments on commit d184d2f

Please sign in to comment.