def _flatten():
    """Build the Relay converter for ``torch.flatten``.

    Returns a closure ``_impl(inputs, input_types)`` where ``inputs[0]`` is
    the data tensor and the optional ``inputs[1]`` / ``inputs[2]`` are
    ``start_dim`` / ``end_dim``, defaulting to 0 / -1 as in PyTorch.

    Raises:
        NotImplementedError: for any (start_dim, end_dim) combination other
            than full 1-D flatten (0, -1) or batch flatten (1, -1).
    """
    def _impl(inputs, input_types):
        data = inputs[0]
        # Bug fix: inputs[1] exists only when len(inputs) > 1, and
        # inputs[2] only when len(inputs) > 2. The previous guards
        # (> 0 and > 1) were off by one, so a single-element inputs
        # list raised IndexError instead of using the defaults.
        start_dim = inputs[1] if len(inputs) > 1 else 0
        end_dim = inputs[2] if len(inputs) > 2 else -1

        # torch.flatten(x): collapse everything into a 1-D tensor.
        if start_dim == 0 and end_dim == -1:
            return _op.transform.reshape(data, (-1,))
        # torch.flatten(x, 1): keep the batch dim, flatten the rest.
        if start_dim == 1 and end_dim == -1:
            return _op.nn.batch_flatten(data)

        raise NotImplementedError("Only support 1d flatten or batch flatten")
    return _impl
class ToDouble(Module):
    """Test module that casts its input to float64 via ``Tensor.double``."""

    def forward(self, tensor):
        # Same op as the original: aten::to(torch.float64) in the trace.
        converted = tensor.double()
        return converted