diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0845fec4ac82..269eb4c67e6a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -442,6 +442,36 @@ def _impl(inputs, input_types): scale=scale)[0] return _impl +def _instance_norm(): + def _impl(inputs, input_types): + data = inputs[0] + data_type = input_types[0] + channels = _infer_shape(data) + + if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr): + scale = center = True + weight = inputs[1] + beta = inputs[2] + gamma = weight + else: + scale = center = False + + if not scale: + gamma = _create_typed_const(np.ones([int(channels[1])]), data_type) + + if not center: + beta = _create_typed_const(np.zeros([int(channels[1])]), data_type) + + epsilon = float(inputs[7]) + return _op.nn.instance_norm(data, + gamma, + beta, + axis=1, + epsilon=epsilon, + center=center, + scale=scale) + return _impl + def _transpose(): def _impl(inputs, input_types): data = inputs[0] @@ -965,6 +995,7 @@ def _wrap_const(c): "aten::threshold_" : _threshold(), "aten::contiguous" : _contiguous(), "aten::batch_norm" : _batch_norm(), + "aten::instance_norm" : _instance_norm(), "aten::transpose" : _transpose(), "aten::transpose_" : _transpose(), "aten::t" : _transpose(), @@ -978,6 +1009,8 @@ def _wrap_const(c): "aten::avg_pool2d" : _avg_pool2d(), "aten::dropout" : _dropout(), "aten::dropout_" : _dropout(), + "aten::feature_dropout" : _dropout(), + "aten::alpha_dropout" : _dropout(), "aten::mean" : _mean(), "aten::chunk" : _chunk(), "aten::matmul" : _matmul(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index ccc9a39fe20d..1f083cbf5c02 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -512,6 +512,15 @@ def init_weight(m): verify_model(bn.eval(), input_data=inp) +def test_forward_instancenorm(): + inp_2d = torch.rand((1, 16, 10, 10)) + inp_3d = torch.rand((1, 16, 10, 10, 10)) + + for ins_norm, inp in [(torch.nn.InstanceNorm2d(16), inp_2d), + (torch.nn.InstanceNorm3d(16), inp_3d)]: + verify_model(ins_norm.eval(), input_data=inp) + + def test_forward_transpose(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -619,13 +628,11 @@ def forward(self, *args): def test_forward_dropout(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] - - class Dropout1(Module): - def forward(self, *args): - return torch.nn.functional.dropout(args[0][0, 0], 0.5, False) - input_data = torch.rand(input_shape).float() - verify_model(Dropout1().float().eval(), input_data=input_data) + verify_model(torch.nn.Dropout(p=0.5).eval(), input_data=input_data[0, 0]) + verify_model(torch.nn.Dropout2d(p=0.5).eval(), input_data=input_data[0]) + verify_model(torch.nn.Dropout3d(p=0.5).eval(), input_data=input_data) + verify_model(torch.nn.AlphaDropout(p=0.5).eval(), input_data=input_data[0, 0]) def test_forward_slice(): torch.set_grad_enabled(False) @@ -1080,6 +1087,7 @@ def forward(self, xs): test_forward_threshold() test_forward_contiguous() test_forward_batchnorm() + test_forward_instancenorm() test_forward_transpose() test_forward_size() test_forward_view()