Skip to content

Commit

Permalink
[PYTORCH]LayerNorm support added (#5249)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored Apr 6, 2020
1 parent 5e50f47 commit 0cc2661
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,34 @@ def _impl(inputs, input_types):
scale=scale)
return _impl

def _get_dims(data):
import torch
if isinstance(data, _expr.Expr):
dims = _infer_shape(data)
elif isinstance(data, list):
dims = data
elif isinstance(data, (torch.Tensor, np.ndarray)):
dims = data.shape
else:
msg = "Data type %s could not be parsed" % type(data)
raise AssertionError(msg)
return dims

def _layer_norm():
def _impl(inputs, input_types):
data = inputs[0]
ndims = len(_get_dims(inputs[1]))
assert ndims == 1, "Support only normalization over last one dimension."

return _op.nn.layer_norm(data,
gamma=inputs[1],
beta=inputs[2],
axis=-1,
epsilon=float(inputs[4]),
center=False,
scale=False)
return _impl

def _transpose():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -1050,6 +1078,7 @@ def _wrap_const(c):
"aten::contiguous" : _contiguous(),
"aten::batch_norm" : _batch_norm(),
"aten::instance_norm" : _instance_norm(),
"aten::layer_norm" : _layer_norm(),
"aten::transpose" : _transpose(),
"aten::transpose_" : _transpose(),
"aten::t" : _transpose(),
Expand Down
4 changes: 4 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ def test_forward_instancenorm():
(torch.nn.InstanceNorm3d(16), inp_3d)]:
verify_model(ins_norm.eval(), input_data=inp)

def test_forward_layernorm():
inp = torch.rand((20, 5, 10, 10))
verify_model(torch.nn.LayerNorm(10).eval(), input_data=inp)

def test_forward_transpose():
torch.set_grad_enabled(False)
Expand Down Expand Up @@ -1132,6 +1135,7 @@ def forward(self, xs):
test_forward_contiguous()
test_forward_batchnorm()
test_forward_instancenorm()
test_forward_layernorm()
test_forward_transpose()
test_forward_size()
test_forward_view()
Expand Down

0 comments on commit 0cc2661

Please sign in to comment.