From a2e7c0ee25cbebf94395aada6f6c73fecb79d381 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 28 Aug 2020 22:29:16 +0900 Subject: [PATCH 1/4] support cast to double and fix flatten conversion --- python/tvm/relay/frontend/pytorch.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 21cf9c3a1b97..948aead2491c 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -999,7 +999,20 @@ def _impl(inputs, input_types): def _flatten(): def _impl(inputs, input_types): data = inputs[0] - return _op.nn.batch_flatten(data) + start_dim = 0 + end_dim = -1 + + if len(inputs) > 0: + start_dim = inputs[1] + if len(inputs) > 1: + end_dim = inputs[2] + + if start_dim != 0 or end_dim != -1: + msg = "Only support flatten to 1d tensor" + raise NotImplementedError(msg) + + return _op.transform.reshape(data, (-1,)) + return _impl def _dense(): @@ -1509,11 +1522,13 @@ def _impl(inputs, input_types): # this happens when converting upsampling with scale factor cast_func = { 6: float, + 7: float, 3: int, 4: int } cast_func_expr = { 6: lambda x: _op.cast(x, "float32"), + 7: lambda x: _op.cast(x, "float64"), 3: lambda x: _op.cast(x, "int32"), 4: lambda x: _op.cast(x, "int64"), } From 4d4f8583706184770a60ebe7d44c4182810caad1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 28 Aug 2020 22:39:55 +0900 Subject: [PATCH 2/4] also support batch flatten, add test --- python/tvm/relay/frontend/pytorch.py | 10 +- tests/python/frontend/pytorch/test_forward.py | 312 ++++++++++-------- 2 files changed, 172 insertions(+), 150 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 948aead2491c..e0aafa4082c7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1007,11 +1007,13 @@ def _impl(inputs, input_types): if len(inputs) > 1: end_dim = inputs[2] - if start_dim != 0 or end_dim != -1: - msg = "Only support flatten to 1d tensor" - raise NotImplementedError(msg) + if start_dim == 0 and end_dim == -1: + return _op.transform.reshape(data, (-1,)) + if start_dim == 1 and end_dim == -1: + return _op.nn.batch_flatten(data) - return _op.transform.reshape(data, (-1,)) + msg = "Only support 1d flatten or batch flatten" + raise NotImplementedError(msg) return _impl diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 946712df5086..dfcc5cad35b9 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -881,6 +881,21 @@ def forward(self, *args): verify_model(Reshape1().float().eval(), input_data=input_data) verify_model(Reshape2().float().eval(), input_data=input_data) + +def test_transpose(): + class Flatten(Module): + def forward(self, x): + return torch.flatten(x) + + class BatchFlatten(Module): + def forward(self, x): + return torch.flatten(x, start_dim=1) + + inp = torch.rand((5, 2, 2)) + verify_model(Flatten(), input_data=inp) + verify_model(BatchFlatten(), input_data=inp) + + def test_forward_transpose(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1311,12 +1326,17 @@ class ToLong(Module): def forward(self, x): return x.long() + class ToDouble(Module): + def forward(self, x): + return x.double() + verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32))) verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int)) verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int)) verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32))) verify_model(ToInt().eval(), torch.tensor(0.8)) verify_model(ToLong().eval(), torch.tensor(0.8)) + verify_model(ToDouble().eval(), torch.tensor(0.8)) def test_adaptive_pool3d(): @@ -2804,150 +2824,150 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": - # some structural tests - test_forward_traced_function() - test_forward_dtypes() - test_weight_names() - test_duplicate_weight_use() - - # Single operator tests - test_forward_add() - test_forward_subtract() - test_forward_multiply() - test_forward_matmul() - test_forward_rsub() - test_forward_onehot() - test_forward_embedding() - test_forward_reshape() - test_forward_reciprocal() - test_forward_repeat() - test_forward_repeat_interleave() - test_forward_squeeze() - test_forward_unsqueeze() - test_forward_concatenate() - test_forward_reduce_sum() - test_forward_reduce_prod() - test_forward_argmin() - test_forward_argmax() - test_forward_norm() - test_forward_frobenius_norm() - test_forward_std() - test_forward_variance() - test_forward_relu() - test_forward_prelu() - test_forward_leakyrelu() - test_forward_elu() - test_forward_celu() - test_forward_gelu() - test_forward_selu() - test_forward_log_sigmoid() - test_forward_adaptiveavgpool() - test_forward_maxpool2d() - test_forward_maxpool1d() - test_forward_maxpool3d() - test_forward_hardtanh() - test_forward_conv() - test_forward_conv_transpose() - test_forward_threshold() - test_forward_contiguous() - test_forward_batchnorm() - test_forward_instancenorm() - test_forward_layernorm() - test_forward_groupnorm() - test_forward_transpose() - test_forward_size() - test_forward_view() - test_forward_select() - test_forward_take() - test_forward_topk() - test_forward_where() - test_forward_addcdiv() - test_forward_addcmul() - test_forward_clone() - test_forward_softplus() - test_forward_softsign() - test_forward_logsoftmax() - test_forward_sigmoid() - test_forward_dense() - test_forward_avgpool() - test_forward_avgpool3d() - test_forward_dropout() - test_forward_slice() - test_forward_mean() - test_forward_expand() - test_forward_pow() - test_forward_unary() - test_forward_clamp() - test_forward_logical_not() - test_forward_bitwise_not() - test_forward_bitwise_xor() - test_forward_logical_xor() - test_forward_isfinite() - test_forward_isnan() - test_forward_isinf() - test_forward_ones() - test_forward_ones_like() - test_forward_zeros() - test_forward_zeros_like() - test_forward_full() - test_forward_full_like() - test_forward_linspace() - test_forward_arange() - test_forward_mesh_grid() - test_forward_chunk() - test_forward_split() - test_forward_gather() - test_upsample() - test_forward_upsample3d() - test_forward_nms() + # # some structural tests + # test_forward_traced_function() + # test_forward_dtypes() + # test_weight_names() + # test_duplicate_weight_use() + + # # Single operator tests + # test_forward_add() + # test_forward_subtract() + # test_forward_multiply() + # test_forward_matmul() + # test_forward_rsub() + # test_forward_onehot() + # test_forward_embedding() + # test_forward_reshape() + # test_forward_reciprocal() + # test_forward_repeat() + # test_forward_repeat_interleave() + # test_forward_squeeze() + # test_forward_unsqueeze() + # test_forward_concatenate() + # test_forward_reduce_sum() + # test_forward_reduce_prod() + # test_forward_argmin() + # test_forward_argmax() + # test_forward_norm() + # test_forward_frobenius_norm() + # test_forward_std() + # test_forward_variance() + # test_forward_relu() + # test_forward_prelu() + # test_forward_leakyrelu() + # test_forward_elu() + # test_forward_celu() + # test_forward_gelu() + # test_forward_selu() + # test_forward_log_sigmoid() + # test_forward_adaptiveavgpool() + # test_forward_maxpool2d() + # test_forward_maxpool1d() + # test_forward_maxpool3d() + # test_forward_hardtanh() + # test_forward_conv() + # test_forward_conv_transpose() + # test_forward_threshold() + # test_forward_contiguous() + # test_forward_batchnorm() + # test_forward_instancenorm() + # test_forward_layernorm() + # test_forward_groupnorm() + # test_forward_transpose() + # test_forward_size() + # test_forward_view() + # test_forward_select() + # test_forward_take() + # test_forward_topk() + # test_forward_where() + # test_forward_addcdiv() + # test_forward_addcmul() + # test_forward_clone() + # test_forward_softplus() + # test_forward_softsign() + # test_forward_logsoftmax() + # test_forward_sigmoid() + # test_forward_dense() + # test_forward_avgpool() + # test_forward_avgpool3d() + # test_forward_dropout() + # test_forward_slice() + # test_forward_mean() + # test_forward_expand() + # test_forward_pow() + # test_forward_unary() + # test_forward_clamp() + # test_forward_logical_not() + # test_forward_bitwise_not() + # test_forward_bitwise_xor() + # test_forward_logical_xor() + # test_forward_isfinite() + # test_forward_isnan() + # test_forward_isinf() + # test_forward_ones() + # test_forward_ones_like() + # test_forward_zeros() + # test_forward_zeros_like() + # test_forward_full() + # test_forward_full_like() + # test_forward_linspace() + # test_forward_arange() + # test_forward_mesh_grid() + # test_forward_chunk() + # test_forward_split() + # test_forward_gather() + # test_upsample() + # test_forward_upsample3d() + # test_forward_nms() test_to() - test_type_as() - test_forward_functional_pad() - test_forward_zero_pad2d() - test_forward_constant_pad1d() - test_forward_constant_pad2d() - test_forward_constant_pad3d() - test_forward_reflection_pad1d() - test_forward_reflection_pad2d() - test_forward_replication_pad1d() - test_forward_replication_pad2d() - test_forward_replication_pad3d() - test_adaptive_pool3d() - test_conv3d() - test_conv3d_transpose() - test_forward_index() - - # Model tests - test_resnet18() - test_squeezenet1_0() - test_squeezenet1_1() - test_densenet121() - # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # test_inception_v3() - test_googlenet() - test_mnasnet0_5() - test_mobilenet_v2() - - test_custom_conversion_map() - - test_segmentaton_models() - test_3d_models() - - # Quantization test - from qnn_test import test_quantized_imagenet, test_quantized_modules - - test_quantized_modules() - test_quantized_imagenet() - - # Test simple conditionals and loop - test_control_flow() - test_simple_rnn() - - # More complex recurrent models - from lstm_test import custom_lstm_test - - custom_lstm_test() - - # Test bert model - test_forward_pretrained_bert_base_uncased() + # test_type_as() + # test_forward_functional_pad() + # test_forward_zero_pad2d() + # test_forward_constant_pad1d() + # test_forward_constant_pad2d() + # test_forward_constant_pad3d() + # test_forward_reflection_pad1d() + # test_forward_reflection_pad2d() + # test_forward_replication_pad1d() + # test_forward_replication_pad2d() + # test_forward_replication_pad3d() + # test_adaptive_pool3d() + # test_conv3d() + # test_conv3d_transpose() + # test_forward_index() + + # # Model tests + # test_resnet18() + # test_squeezenet1_0() + # test_squeezenet1_1() + # test_densenet121() + # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # # test_inception_v3() + # test_googlenet() + # test_mnasnet0_5() + # test_mobilenet_v2() + + # test_custom_conversion_map() + + # test_segmentaton_models() + # test_3d_models() + + # # Quantization test + # from qnn_test import test_quantized_imagenet, test_quantized_modules + + # test_quantized_modules() + # test_quantized_imagenet() + + # # Test simple conditionals and loop + # test_control_flow() + # test_simple_rnn() + + # # More complex recurrent models + # from lstm_test import custom_lstm_test + + # custom_lstm_test() + + # # Test bert model + # test_forward_pretrained_bert_base_uncased() From d40ca202928b0bd2cc7fe391c89bf6cd906b04fe Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 28 Aug 2020 22:48:35 +0900 Subject: [PATCH 3/4] add flatten test --- python/tvm/relay/frontend/pytorch.py | 5 +- tests/python/frontend/pytorch/test_forward.py | 295 +++++++++--------- 2 files changed, 151 insertions(+), 149 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index e0aafa4082c7..7e1985b9e16f 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -996,6 +996,7 @@ def _impl(inputs, input_types): return _op.transform.transpose(data, axes) return _impl + def _flatten(): def _impl(inputs, input_types): data = inputs[0] @@ -1012,11 +1013,11 @@ def _impl(inputs, input_types): if start_dim == 1 and end_dim == -1: return _op.nn.batch_flatten(data) - msg = "Only support 1d flatten or batch flatten" - raise NotImplementedError(msg) + raise NotImplementedError("Only support 1d flatten or batch flatten") return _impl + def _dense(): def _impl(inputs, input_types): use_bias = isinstance(inputs[0], _expr.Expr) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index dfcc5cad35b9..2e54ac4b4719 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -882,7 +882,7 @@ def forward(self, *args): verify_model(Reshape2().float().eval(), input_data=input_data) -def test_transpose(): +def test_flatten(): class Flatten(Module): def forward(self, x): return torch.flatten(x) @@ -2824,150 +2824,151 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": - # # some structural tests - # test_forward_traced_function() - # test_forward_dtypes() - # test_weight_names() - # test_duplicate_weight_use() - - # # Single operator tests - # test_forward_add() - # test_forward_subtract() - # test_forward_multiply() - # test_forward_matmul() - # test_forward_rsub() - # test_forward_onehot() - # test_forward_embedding() - # test_forward_reshape() - # test_forward_reciprocal() - # test_forward_repeat() - # test_forward_repeat_interleave() - # test_forward_squeeze() - # test_forward_unsqueeze() - # test_forward_concatenate() - # test_forward_reduce_sum() - # test_forward_reduce_prod() - # test_forward_argmin() - # test_forward_argmax() - # test_forward_norm() - # test_forward_frobenius_norm() - # test_forward_std() - # test_forward_variance() - # test_forward_relu() - # test_forward_prelu() - # test_forward_leakyrelu() - # test_forward_elu() - # test_forward_celu() - # test_forward_gelu() - # test_forward_selu() - # test_forward_log_sigmoid() - # test_forward_adaptiveavgpool() - # test_forward_maxpool2d() - # test_forward_maxpool1d() - # test_forward_maxpool3d() - # test_forward_hardtanh() - # test_forward_conv() - # test_forward_conv_transpose() - # test_forward_threshold() - # test_forward_contiguous() - # test_forward_batchnorm() - # test_forward_instancenorm() - # test_forward_layernorm() - # test_forward_groupnorm() - # test_forward_transpose() - # test_forward_size() - # test_forward_view() - # test_forward_select() - # test_forward_take() - # test_forward_topk() - # test_forward_where() - # test_forward_addcdiv() - # test_forward_addcmul() - # test_forward_clone() - # test_forward_softplus() - # test_forward_softsign() - # test_forward_logsoftmax() - # test_forward_sigmoid() - # test_forward_dense() - # test_forward_avgpool() - # test_forward_avgpool3d() - # test_forward_dropout() - # test_forward_slice() - # test_forward_mean() - # test_forward_expand() - # test_forward_pow() - # test_forward_unary() - # test_forward_clamp() - # test_forward_logical_not() - # test_forward_bitwise_not() - # test_forward_bitwise_xor() - # test_forward_logical_xor() - # test_forward_isfinite() - # test_forward_isnan() - # test_forward_isinf() - # test_forward_ones() - # test_forward_ones_like() - # test_forward_zeros() - # test_forward_zeros_like() - # test_forward_full() - # test_forward_full_like() - # test_forward_linspace() - # test_forward_arange() - # test_forward_mesh_grid() - # test_forward_chunk() - # test_forward_split() - # test_forward_gather() - # test_upsample() - # test_forward_upsample3d() - # test_forward_nms() + # some structural tests + test_forward_traced_function() + test_forward_dtypes() + test_weight_names() + test_duplicate_weight_use() + + # Single operator tests + test_forward_add() + test_forward_subtract() + test_forward_multiply() + test_forward_matmul() + test_forward_rsub() + test_forward_onehot() + test_forward_embedding() + test_forward_reshape() + test_forward_reciprocal() + test_forward_repeat() + test_forward_repeat_interleave() + test_forward_squeeze() + test_forward_unsqueeze() + test_forward_concatenate() + test_forward_reduce_sum() + test_forward_reduce_prod() + test_forward_argmin() + test_forward_argmax() + test_forward_norm() + test_forward_frobenius_norm() + test_forward_std() + test_forward_variance() + test_forward_relu() + test_forward_prelu() + test_forward_leakyrelu() + test_forward_elu() + test_forward_celu() + test_forward_gelu() + test_forward_selu() + test_forward_log_sigmoid() + test_forward_adaptiveavgpool() + test_forward_maxpool2d() + test_forward_maxpool1d() + test_forward_maxpool3d() + test_forward_hardtanh() + test_forward_conv() + test_forward_conv_transpose() + test_forward_threshold() + test_forward_contiguous() + test_forward_batchnorm() + test_forward_instancenorm() + test_forward_layernorm() + test_forward_groupnorm() + test_forward_transpose() + test_forward_size() + test_forward_view() + test_forward_select() + test_forward_take() + test_forward_topk() + test_forward_where() + test_forward_addcdiv() + test_forward_addcmul() + test_forward_clone() + test_forward_softplus() + test_forward_softsign() + test_forward_logsoftmax() + test_forward_sigmoid() + test_forward_dense() + test_forward_avgpool() + test_forward_avgpool3d() + test_forward_dropout() + test_forward_slice() + test_forward_mean() + test_forward_expand() + test_forward_pow() + test_forward_unary() + test_forward_clamp() + test_forward_logical_not() + test_forward_bitwise_not() + test_forward_bitwise_xor() + test_forward_logical_xor() + test_forward_isfinite() + test_forward_isnan() + test_forward_isinf() + test_forward_ones() + test_forward_ones_like() + test_forward_zeros() + test_forward_zeros_like() + test_forward_full() + test_forward_full_like() + test_forward_linspace() + test_forward_arange() + test_forward_mesh_grid() + test_forward_chunk() + test_forward_split() + test_forward_gather() + test_upsample() + test_forward_upsample3d() + test_forward_nms() test_to() - # test_type_as() - # test_forward_functional_pad() - # test_forward_zero_pad2d() - # test_forward_constant_pad1d() - # test_forward_constant_pad2d() - # test_forward_constant_pad3d() - # test_forward_reflection_pad1d() - # test_forward_reflection_pad2d() - # test_forward_replication_pad1d() - # test_forward_replication_pad2d() - # test_forward_replication_pad3d() - # test_adaptive_pool3d() - # test_conv3d() - # test_conv3d_transpose() - # test_forward_index() - - # # Model tests - # test_resnet18() - # test_squeezenet1_0() - # test_squeezenet1_1() - # test_densenet121() - # # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug - # # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 - # # test_inception_v3() - # test_googlenet() - # test_mnasnet0_5() - # test_mobilenet_v2() - - # test_custom_conversion_map() - - # test_segmentaton_models() - # test_3d_models() - - # # Quantization test - # from qnn_test import test_quantized_imagenet, test_quantized_modules - - # test_quantized_modules() - # test_quantized_imagenet() - - # # Test simple conditionals and loop - # test_control_flow() - # test_simple_rnn() - - # # More complex recurrent models - # from lstm_test import custom_lstm_test - - # custom_lstm_test() - - # # Test bert model - # test_forward_pretrained_bert_base_uncased() + test_flatten() + test_type_as() + test_forward_functional_pad() + test_forward_zero_pad2d() + test_forward_constant_pad1d() + test_forward_constant_pad2d() + test_forward_constant_pad3d() + test_forward_reflection_pad1d() + test_forward_reflection_pad2d() + test_forward_replication_pad1d() + test_forward_replication_pad2d() + test_forward_replication_pad3d() + test_adaptive_pool3d() + test_conv3d() + test_conv3d_transpose() + test_forward_index() + + # Model tests + test_resnet18() + test_squeezenet1_0() + test_squeezenet1_1() + test_densenet121() + # disable inception test for now, since loading it takes ~5min on torchvision-0.5 due to scipy bug + # See https://discuss.pytorch.org/t/torchvisions-inception-v3-takes-much-longer-to-load-than-other-models/68756 + # test_inception_v3() + test_googlenet() + test_mnasnet0_5() + test_mobilenet_v2() + + test_custom_conversion_map() + + test_segmentaton_models() + test_3d_models() + + # Quantization test + from qnn_test import test_quantized_imagenet, test_quantized_modules + + test_quantized_modules() + test_quantized_imagenet() + + # Test simple conditionals and loop + test_control_flow() + test_simple_rnn() + + # More complex recurrent models + from lstm_test import custom_lstm_test + + custom_lstm_test() + + # Test bert model + test_forward_pretrained_bert_base_uncased() From 32a78cb3542b53bfe0fe61965bcafe2adbd968ea Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 29 Aug 2020 05:22:42 +0900 Subject: [PATCH 4/4] clean up --- python/tvm/relay/frontend/pytorch.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 7e1985b9e16f..108d1d8dc1ad 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1000,13 +1000,8 @@ def _impl(inputs, input_types): def _flatten(): def _impl(inputs, input_types): data = inputs[0] - start_dim = 0 - end_dim = -1 - - if len(inputs) > 0: - start_dim = inputs[1] - if len(inputs) > 1: - end_dim = inputs[2] + start_dim = inputs[1] if len(inputs) > 0 else 0 + end_dim = inputs[2] if len(inputs) > 1 else -1 if start_dim == 0 and end_dim == -1: return _op.transform.reshape(data, (-1,))