[Torch] Add cast to double, fix flatten conversion (#6357)
* support cast to double and fix flatten conversion

* also support batch flatten, add test

* add flatten test

* clean up
masahi authored Aug 29, 2020
1 parent d9450f8 commit 2d752d2
Showing 2 changed files with 35 additions and 1 deletion.
15 changes: 14 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -996,12 +996,23 @@ def _impl(inputs, input_types):
         return _op.transform.transpose(data, axes)
     return _impl
 
+
 def _flatten():
     def _impl(inputs, input_types):
         data = inputs[0]
-        return _op.nn.batch_flatten(data)
+        start_dim = inputs[1] if len(inputs) > 0 else 0
+        end_dim = inputs[2] if len(inputs) > 1 else -1
+
+        if start_dim == 0 and end_dim == -1:
+            return _op.transform.reshape(data, (-1,))
+        if start_dim == 1 and end_dim == -1:
+            return _op.nn.batch_flatten(data)
+
+        raise NotImplementedError("Only support 1d flatten or batch flatten")
+
     return _impl
 
+
 def _dense():
     def _impl(inputs, input_types):
         use_bias = isinstance(inputs[0], _expr.Expr)
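
For context (not part of the diff): a minimal PyTorch sketch of the two flatten modes the rewritten converter accepts; the commented shapes assume a (5, 2, 2) input like the new test uses.

import torch

x = torch.rand((5, 2, 2))
torch.flatten(x).shape               # torch.Size([20])   -> lowered to _op.transform.reshape(data, (-1,))
torch.flatten(x, start_dim=1).shape  # torch.Size([5, 4]) -> lowered to _op.nn.batch_flatten(data)

Any other start_dim/end_dim combination now raises NotImplementedError instead of falling through to batch_flatten, which was incorrect for the default torch.flatten(x) case.
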
@@ -1509,11 +1520,13 @@ def _impl(inputs, input_types):
         # this happens when converting upsampling with scale factor
         cast_func = {
             6: float,
+            7: float,
             3: int,
             4: int
         }
         cast_func_expr = {
             6: lambda x: _op.cast(x, "float32"),
+            7: lambda x: _op.cast(x, "float64"),
             3: lambda x: _op.cast(x, "int32"),
             4: lambda x: _op.cast(x, "int64"),
         }
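
Background (an assumption noted here for illustration, not part of the commit): the integer keys in cast_func/cast_func_expr are taken to be PyTorch ScalarType codes, which show up as constant dtype arguments of aten::to nodes in traced graphs; this change adds the code-7 (float64) entries. A small sketch:

import torch

# ScalarType codes assumed above: 3 -> int32, 4 -> int64, 6 -> float32, 7 -> float64
x = torch.tensor(0.8)
print(x.double().dtype)  # torch.float64; typically traced as aten::to(..., 7, ...)
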
21 changes: 21 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -881,6 +881,21 @@ def forward(self, *args):
     verify_model(Reshape1().float().eval(), input_data=input_data)
     verify_model(Reshape2().float().eval(), input_data=input_data)
 
+
+def test_flatten():
+    class Flatten(Module):
+        def forward(self, x):
+            return torch.flatten(x)
+
+    class BatchFlatten(Module):
+        def forward(self, x):
+            return torch.flatten(x, start_dim=1)
+
+    inp = torch.rand((5, 2, 2))
+    verify_model(Flatten(), input_data=inp)
+    verify_model(BatchFlatten(), input_data=inp)
+
+
 def test_forward_transpose():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1311,12 +1326,17 @@ class ToLong(Module):
         def forward(self, x):
             return x.long()
 
+    class ToDouble(Module):
+        def forward(self, x):
+            return x.double()
+
     verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
     verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
     verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
     verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
     verify_model(ToInt().eval(), torch.tensor(0.8))
     verify_model(ToLong().eval(), torch.tensor(0.8))
+    verify_model(ToDouble().eval(), torch.tensor(0.8))
 
 
 def test_adaptive_pool3d():
@@ -2901,6 +2921,7 @@ def test_forward_pretrained_bert_base_uncased():
     test_forward_upsample3d()
     test_forward_nms()
     test_to()
+    test_flatten()
     test_type_as()
     test_forward_functional_pad()
     test_forward_zero_pad2d()
