Skip to content

Commit

Permalink
Fix scalar detection using numpy.isscalar
Browse files Browse the repository at this point in the history
and address other review comments. Thank you @siju-samuel
  • Loading branch information
t-vi committed Jun 18, 2020
1 parent 509f520 commit 94b3952
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 4 additions & 5 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,14 +1752,12 @@ def _pytorch_result_type(dtypes, non_tensor_inputs):

def _pytorch_promote_types(inputs, dtypes):
"""This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes)
if not isinstance(inp, (float, int, bool))]
non_tensor_inputs = [inp for inp in inputs
if isinstance(inp, (float, int, bool))]
tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
results = []
for inp, dt in zip(inputs, dtypes):
if isinstance(inp, (float, int, bool)):
if np.isscalar(inp):
results.append(_expr.const(inp, dtype=result_type))
elif dt == result_type:
results.append(inp)
Expand Down Expand Up @@ -2028,6 +2026,7 @@ def _run_jit_passes(graph):


def _is_int_seq(seq):
# TODO (t-vi): handle non-int constants? (like numpy.intXX)
return len(seq) > 0 and all([isinstance(i, int) for i in seq])


Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2546,6 +2546,8 @@ def test_forward_pretrained_bert_base_uncased():


if __name__ == "__main__":
test_forward_traced_function()
test_forward_dtypes()
# Single operator tests
test_forward_add()
test_forward_subtract()
Expand Down

0 comments on commit 94b3952

Please sign in to comment.