Skip to content

Commit

Permalink
[PYTORCH]Minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed May 28, 2020
1 parent a072da0 commit 7bd4309
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .common import get_relay_op
from .common import infer_shape as _infer_shape
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import infer_type as _infer_type
from ..prelude import Prelude, StaticTensorArrayOps

Expand Down Expand Up @@ -152,19 +153,33 @@ def _impl(inputs, input_types):

def _arange():
def _impl(inputs, input_types):
def _get_value(val, dtype):
if isinstance(val, _expr.Expr):
return _op.cast(val, _convert_data_type(dtype))
return _create_typed_const(val, dtype)

def _get_type(val, inp_type):
if isinstance(val, _expr.Expr):
dtype = str(_infer_type(val).checked_type)
return dtype if dtype != "float32" else "float"
return inp_type

if len(inputs) == 5:
dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1])
start = _create_typed_const(0, dtype)
stop = _create_typed_const(inputs[0], dtype)
step = _create_typed_const(1, dtype)
dtype0 = _get_type(inputs[0], input_types[0])
dtype = "float" if dtype0 == "float" else _convert_dtype_value(inputs[1])
start = _get_value(0, dtype)
stop = _get_value(inputs[0], dtype)
step = _get_value(1, dtype)
elif len(inputs) == 7:
dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
start = _create_typed_const(inputs[0], dtype)
stop = _create_typed_const(inputs[1], dtype)
step = _create_typed_const(inputs[2], dtype)
types = [_get_type(inputs[i], input_types[i]) for i in range(3)]
dtype = "float" if "float" in types else _convert_dtype_value(inputs[3])
start = _get_value(inputs[0], dtype)
stop = _get_value(inputs[1], dtype)
step = _get_value(inputs[2], dtype)
else:
msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
raise AssertionError(msg)

return _op.transform.arange(start=start,
stop=stop,
step=step,
Expand Down Expand Up @@ -235,12 +250,18 @@ def _impl(inputs, input_types):

begin = [0] * len(end)
dim = int(inputs[1])
begin[dim] = int(inputs[2])
if isinstance(inputs[2], _expr.Call):
begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
else:
begin[dim] = int(inputs[2])

if isinstance(inputs[3], str) and inputs[3].isdigit():
end[dim] = min(end[dim], int(inputs[3]))
else:
end[dim] = inputs[3]
if isinstance(inputs[3], _expr.Call):
end[dim] = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
else:
end[dim] = inputs[3]

strides.append(int(inputs[4]))
return _op.transform.strided_slice(data, begin, end, strides)
Expand Down Expand Up @@ -997,7 +1018,10 @@ def _impl(inputs, input_types):
def _numtotensor():
def _impl(inputs, input_types):
val = inputs[0]
dtype = type(val)
dtype = input_types[0]

if isinstance(val, _expr.Expr):
return val

if isinstance(val, tvm.tir.IntImm):
val = val.__int__()
Expand All @@ -1019,16 +1043,22 @@ def _impl(inputs, input_types):
data = inputs[0]

if len(inputs) == 3:
new_shape = [inputs[1], _infer_shape(inputs[2])[0]]
shape_inp = [inputs[1], _infer_shape(inputs[2])[0]]
else:
if isinstance(inputs[1], list):
new_shape = inputs[1]
shape_inp = inputs[1]
else:
new_shape = _infer_shape(inputs[1])
shape_inp = _infer_shape(inputs[1])
new_shape = shape_inp
for i, shape in enumerate(shape_inp):
if isinstance(shape, _expr.Expr):
val = _infer_value_simulated(shape, {})
new_shape[i] = np.asscalar(val.asnumpy())

return _op.transform.reshape(data, new_shape)
return _impl


def _reshape():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down

0 comments on commit 7bd4309

Please sign in to comment.