[PYTORCH]ReplicationPad support added (apache#5708)
siju-samuel authored and Trevor Morris committed Jun 18, 2020
1 parent: 8133520, commit: 21c137a
Showing 2 changed files with 51 additions and 21 deletions.
python/tvm/relay/frontend/pytorch.py (32 changes: 11 additions, 21 deletions)
@@ -1369,7 +1369,7 @@ def _impl(inputs, input_types):
         return None
     return _impl
 
-def _pad():
+def _pad(mode):
     def _impl(inputs, input_types):
         data = inputs[0]
         if isinstance(inputs[1], list):
@@ -1394,9 +1394,11 @@ def _impl(inputs, input_types):
         # group into tuple of 2 ints
         paddings = [paddings[i:i + 2] for i in range(0, len(paddings), 2)]
 
-        pad_value = inputs[2]
+        if mode == "constant":
+            return _op.nn.pad(data, paddings, pad_value=inputs[2], pad_mode=mode)
+        else:
+            return _op.nn.pad(data, paddings, pad_mode=mode)
 
-        return _op.nn.pad(data, paddings, pad_value)
     return _impl
 
 
@@ -1654,22 +1656,6 @@ def _impl(inputs, input_types):
     return _impl
 
 
-def _reflection_pad2d():
-    def _impl(inputs, input_types):
-        if isinstance(inputs[1], list):
-            pad_list = inputs[1]
-        else:
-            pad_list = list(_infer_shape(inputs[1]))
-        padding_left = pad_list[0]
-        padding_right = pad_list[1]
-        padding_top = pad_list[2]
-        padding_bottom = pad_list[3]
-        paddings = [[0, 0], [0, 0], [padding_top, padding_bottom], [padding_left, padding_right]]
-
-        return _op.nn.mirror_pad(inputs[0], paddings, mode='REFLECT')
-    return _impl
-
-
 # Helper functions for operator implementation
 def _convert_dtype_value(val):
     convert_torch_dtype_map = {7:"torch.float64",
@@ -1836,7 +1822,12 @@ def _get_convert_map(prelude):
         "aten::Int" : _int(),
         "prim::NumToTensor" : _numtotensor(),
         "prim::ImplicitTensorToNum" : _tensortonum(),
-        "aten::constant_pad_nd" : _pad(),
+        "aten::constant_pad_nd" : _pad("constant"),
+        "aten::reflection_pad1d" : _pad("reflect"),
+        "aten::reflection_pad2d" : _pad("reflect"),
+        "aten::replication_pad1d" : _pad("edge"),
+        "aten::replication_pad2d" : _pad("edge"),
+        "aten::replication_pad3d" : _pad("edge"),
         "aten::permute" : _transpose(prelude),
         "aten::sum" : _reduce("sum"),
         "aten::prod" : _reduce("prod"),
@@ -1895,7 +1886,6 @@ def _get_convert_map(prelude):
         "aten::embedding" : _embedding(),
         "aten::one_hot" : _one_hot(),
         "aten::mm" : _matmul(prelude),
-        "aten::reflection_pad2d" : _reflection_pad2d(),
         "relay::tensor_array_stack" : _tensor_array_stack(prelude),
         "aten::add" : _add(prelude),
         "aten::add_" : _add(prelude),
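
In short, the dedicated _reflection_pad2d converter is dropped and a single _pad(mode) factory is registered once per ATen padding op, with the mode forwarded to relay's nn.pad. The sketch below is illustrative only (RELAY_PAD_MODE and group_pads are made-up names, not part of the patch); it summarizes the resulting op-to-mode mapping and the grouping step visible in the hunk above, while the full converter also expands the flat pad list to cover leading axes before grouping.

```python
# Illustrative sketch, not code from the commit.
RELAY_PAD_MODE = {
    "aten::constant_pad_nd":   "constant",
    "aten::reflection_pad1d":  "reflect",
    "aten::reflection_pad2d":  "reflect",
    "aten::replication_pad1d": "edge",
    "aten::replication_pad2d": "edge",
    "aten::replication_pad3d": "edge",
}

def group_pads(flat_pads):
    """Group PyTorch's flat pad list, e.g. (left, right, top, bottom),
    into the per-axis (before, after) pairs that relay's nn.pad expects."""
    return [flat_pads[i:i + 2] for i in range(0, len(flat_pads), 2)]

print(RELAY_PAD_MODE["aten::replication_pad2d"])  # edge
print(group_pads([1, 1, 2, 0]))                   # [[1, 1], [2, 0]]
```
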
tests/python/frontend/pytorch/test_forward.py (40 changes: 40 additions, 0 deletions)
@@ -1116,6 +1116,15 @@ def test_forward_constant_pad3d():
     verify_model(torch.nn.ConstantPad3d((3, 4, 5, 6, 0, 1), 3.5).eval(), inp)
 
 
+def test_forward_reflection_pad1d():
+    inp = torch.rand((1, 2, 4))
+    verify_model(torch.nn.ReflectionPad1d(2).eval(), inp)
+    verify_model(torch.nn.ReflectionPad1d((3, 1)).eval(), inp)
+
+    inp = torch.rand((2, 4, 5))
+    verify_model(torch.nn.ReflectionPad1d((2, 3)).eval(), inp)
+
+
 def test_forward_reflection_pad2d():
     inp = torch.rand((1, 1, 3, 3))
     verify_model(torch.nn.ReflectionPad2d(2).eval(), inp)
@@ -1125,6 +1134,33 @@ def test_forward_reflection_pad2d():
     verify_model(torch.nn.ReflectionPad2d((1, 3, 2, 4)).eval(), inp)
 
 
+def test_forward_replication_pad1d():
+    inp = torch.rand((1, 2, 4))
+    verify_model(torch.nn.ReplicationPad1d(2).eval(), inp)
+    verify_model(torch.nn.ReplicationPad1d((3, 1)).eval(), inp)
+
+    inp = torch.rand((2, 4, 5))
+    verify_model(torch.nn.ReplicationPad1d((2, 3)).eval(), inp)
+
+
+def test_forward_replication_pad2d():
+    inp = torch.rand((1, 1, 3, 3))
+    verify_model(torch.nn.ReplicationPad2d(2).eval(), inp)
+    verify_model(torch.nn.ReplicationPad2d((1, 1, 2, 0)).eval(), inp)
+
+    inp = torch.rand((2, 4, 5, 6))
+    verify_model(torch.nn.ReplicationPad2d((1, 3, 2, 4)).eval(), inp)
+
+
+def test_forward_replication_pad3d():
+    inp = torch.rand((1, 1, 3, 3, 3))
+    verify_model(torch.nn.ReplicationPad3d(3).eval(), inp)
+    verify_model(torch.nn.ReplicationPad3d((1, 1, 2, 2, 1, 1)).eval(), inp)
+
+    inp = torch.rand((7, 5, 4, 5, 6))
+    verify_model(torch.nn.ReplicationPad3d((2, 3, 2, 5, 1, 4)).eval(), inp)
+
+
 def test_forward_upsample3d():
     inp = torch.arange(1, 9, dtype=torch.float32).view(1, 1, 2, 2, 2)
     verify_model(torch.nn.Upsample(scale_factor=2, mode='nearest').eval(), inp)
@@ -2429,7 +2465,11 @@ def test_forward_pretrained_bert_base_uncased():
     test_forward_constant_pad1d()
     test_forward_constant_pad2d()
     test_forward_constant_pad3d()
+    test_forward_reflection_pad1d()
     test_forward_reflection_pad2d()
+    test_forward_replication_pad1d()
+    test_forward_replication_pad2d()
+    test_forward_replication_pad3d()
     test_adaptive_pool3d()
     test_conv3d()
 
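
The new tests drive the converter end to end through verify_model. A hand-written usage sketch of what that amounts to is shown below; the input name "input0", the shapes, and the build step are illustrative, and the exact from_pytorch signature can differ between TVM versions.

```python
# Hand-written sketch, not part of the commit; API details may vary by TVM version.
import torch
import tvm
from tvm import relay

model = torch.nn.ReplicationPad2d((1, 1, 2, 0)).eval()
inp = torch.rand((1, 1, 3, 3))
scripted = torch.jit.trace(model, inp)

# "input0" is an arbitrary input name chosen for this sketch.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(inp.shape))])
print(mod)  # aten::replication_pad2d should now lower to nn.pad with pad_mode="edge"

# Optionally compile for CPU and compare outputs against PyTorch,
# which is roughly what verify_model automates.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)
```
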
