[PYTORCH]Dropouts And InstanceNorm support added (apache#5203)
* [PYTORCH]Dropouts And InstanceNorm support added

* Review comments fixed
siju-samuel authored and Trevor Morris committed Apr 16, 2020
1 parent 4dacc71 commit 773dd80
Showing 2 changed files with 47 additions and 6 deletions.
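For orientation, here is a minimal sketch of how a model exercising the newly supported ops could be compiled through the frontend. The toy module, variable names, and the (name, shape) input format passed to from_pytorch are assumptions based on the test suite of this period, not part of the commit:

import torch
from tvm import relay

# Hypothetical toy module combining InstanceNorm and Dropout
# (nn.Dropout2d traces to aten::feature_dropout).
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.InstanceNorm2d(16)
        self.drop = torch.nn.Dropout2d(p=0.5)

    def forward(self, x):
        return self.drop(self.norm(x))

inp = torch.rand(1, 16, 10, 10)
trace = torch.jit.trace(Net().eval(), inp)

# aten::instance_norm and aten::feature_dropout are now convertible.
mod, params = relay.frontend.from_pytorch(trace, [("input0", (1, 16, 10, 10))])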
33 changes: 33 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -442,6 +442,36 @@ def _impl(inputs, input_types):
                                  scale=scale)[0]
     return _impl

+def _instance_norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        data_type = input_types[0]
+        channels = _infer_shape(data)
+
+        if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr):
+            scale = center = True
+            weight = inputs[1]
+            beta = inputs[2]
+            gamma = weight
+        else:
+            scale = center = False
+
+        if not scale:
+            gamma = _create_typed_const(np.ones([int(channels[1])]), data_type)
+
+        if not center:
+            beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)
+
+        epsilon = float(inputs[7])
+        return _op.nn.instance_norm(data,
+                                    gamma,
+                                    beta,
+                                    axis=1,
+                                    epsilon=epsilon,
+                                    center=center,
+                                    scale=scale)
+    return _impl

 def _transpose():
     def _impl(inputs, input_types):
         data = inputs[0]
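For reference, instance normalization standardizes each (sample, channel) slice over its spatial axes and then applies a per-channel gamma and beta, which is why the converter builds gamma/beta of size channels[1] and reads eps from inputs[7] (the position of eps in the aten::instance_norm argument list). A minimal NumPy sketch of the expected computation; the helper name below is hypothetical:

import numpy as np

def instance_norm_ref(x, gamma, beta, eps=1e-5):
    # x has layout (N, C, ...spatial); normalize over the spatial axes only.
    axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # Broadcast the per-channel scale and shift along axis 1.
    shape = (1, -1) + (1,) * (x.ndim - 2)
    return x_hat * gamma.reshape(shape) + beta.reshape(shape)

x = np.random.rand(1, 16, 10, 10).astype("float32")
out = instance_norm_ref(x, np.ones(16, "float32"), np.zeros(16, "float32"))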
@@ -965,6 +995,7 @@ def _wrap_const(c):
"aten::threshold_" : _threshold(),
"aten::contiguous" : _contiguous(),
"aten::batch_norm" : _batch_norm(),
"aten::instance_norm" : _instance_norm(),
"aten::transpose" : _transpose(),
"aten::transpose_" : _transpose(),
"aten::t" : _transpose(),
@@ -978,6 +1009,8 @@ def _wrap_const(c):
"aten::avg_pool2d" : _avg_pool2d(),
"aten::dropout" : _dropout(),
"aten::dropout_" : _dropout(),
"aten::feature_dropout" : _dropout(),
"aten::alpha_dropout" : _dropout(),
"aten::mean" : _mean(),
"aten::chunk" : _chunk(),
"aten::matmul" : _matmul(),
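Routing aten::feature_dropout and aten::alpha_dropout to the same _dropout converter is safe for inference, since every dropout variant is the identity in eval mode. A quick standalone check (a sketch, not part of the commit), mirroring the input shapes the test below uses:

import torch

x = torch.rand(1, 3, 10, 10)
for layer, inp in [(torch.nn.Dropout(0.5), x[0, 0]),
                   (torch.nn.Dropout2d(0.5), x[0]),
                   (torch.nn.Dropout3d(0.5), x),
                   (torch.nn.AlphaDropout(0.5), x[0, 0])]:
    layer.eval()  # inference mode: dropout becomes a pass-through
    assert torch.equal(layer(inp), inp)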
20 changes: 14 additions & 6 deletions tests/python/frontend/pytorch/test_forward.py
@@ -512,6 +512,15 @@ def init_weight(m):
     verify_model(bn.eval(), input_data=inp)


+def test_forward_instancenorm():
+    inp_2d = torch.rand((1, 16, 10, 10))
+    inp_3d = torch.rand((1, 16, 10, 10, 10))
+
+    for ins_norm, inp in [(torch.nn.InstanceNorm2d(16), inp_2d),
+                          (torch.nn.InstanceNorm3d(16), inp_3d)]:
+        verify_model(ins_norm.eval(), input_data=inp)


 def test_forward_transpose():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -619,13 +628,11 @@ def forward(self, *args):
 def test_forward_dropout():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
-
-    class Dropout1(Module):
-        def forward(self, *args):
-            return torch.nn.functional.dropout(args[0][0, 0], 0.5, False)
-
     input_data = torch.rand(input_shape).float()
-    verify_model(Dropout1().float().eval(), input_data=input_data)
+    verify_model(torch.nn.Dropout(p=0.5).eval(), input_data=input_data[0, 0])
+    verify_model(torch.nn.Dropout2d(p=0.5).eval(), input_data=input_data[0])
+    verify_model(torch.nn.Dropout3d(p=0.5).eval(), input_data=input_data)
+    verify_model(torch.nn.AlphaDropout(p=0.5).eval(), input_data=input_data[0, 0])

 def test_forward_slice():
     torch.set_grad_enabled(False)
@@ -1080,6 +1087,7 @@ def forward(self, xs):
     test_forward_threshold()
     test_forward_contiguous()
     test_forward_batchnorm()
+    test_forward_instancenorm()
     test_forward_transpose()
     test_forward_size()
     test_forward_view()
