Skip to content

Commit

Permalink
[PYTORCH]Repeat, Reciprocal & Reshape Op support (#5280)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored Apr 10, 2020
1 parent 0d1babc commit b236565
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
42 changes: 42 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,34 @@ def _impl(inputs, input_types):
return _op.transform.take(data, index, axis=dim)
return _impl

def _reciprocal():
def _impl(inputs, input_types):
data = inputs[0]
return _expr.const(1.0) / data
return _impl

def _repeat():
def _impl(inputs, input_types):
data = inputs[0]
reps = _get_dims(inputs[1])
return _op.transform.tile(data, reps=reps)
return _impl

def _repeat_interleave():
def _impl(inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], int):
repeats = inputs[1]
axis = inputs[2]
else:
msg = "Only repeat with one value as repeat is currently supported."
raise AssertionError(msg)
if axis is None: # Flatten the data if no axis is given from torch
data = _op.transform.reshape(data, [-1])
axis = 0
return _op.transform.repeat(data, repeats=repeats, axis=axis)
return _impl

def _ones():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -675,6 +703,16 @@ def _impl(inputs, input_types):
return _op.transform.reshape(data, new_shape)
return _impl

def _reshape():
def _impl(inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], list):
new_shape = inputs[1]
else:
new_shape = _infer_shape(inputs[1])
return _op.transform.reshape(data, new_shape)
return _impl

def _clone():
def _impl(inputs, input_types):
data = inputs[0]
Expand Down Expand Up @@ -1082,6 +1120,9 @@ def _wrap_const(c):
"aten::div_" : _elemwise("divide"),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::reciprocal" : _reciprocal(),
"aten::repeat" : _repeat(),
"aten::repeat_interleave" : _repeat_interleave(),
"aten::to" : _to(),
"aten::squeeze" : _squeeze(),
"aten::unsqueeze" : _unsqueeze(),
Expand Down Expand Up @@ -1122,6 +1163,7 @@ def _wrap_const(c):
"aten::addmm" : _dense(),
"aten::size" : _size(),
"aten::view" : _view(),
"aten::reshape" : _reshape(),
"aten::clone" : _clone(),
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
Expand Down
75 changes: 75 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,61 @@ def forward(self, *args):
verify_model(Multiply3().float().eval(), input_data=input_data)
verify_model(Multiply4().float().eval(), input_data=input_data)

def test_forward_reciprocal():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
class Reciprocal1(Module):
def forward(self, *args):
return args[0].reciprocal()

input_data = torch.rand(input_shape).float()
verify_model(Reciprocal1().float().eval(), input_data=input_data)

def test_forward_repeat():
torch.set_grad_enabled(False)
input_shape = [1, 3]
class Repeat1(Module):
def forward(self, *args):
return args[0].repeat(1, 1)

class Repeat2(Module):
def forward(self, *args):
return args[0].repeat(4, 2)

class Repeat3(Module):
def forward(self, *args):
return args[0].repeat(4, 2, 1)

input_data = torch.rand(input_shape).float()
verify_model(Repeat1().float().eval(), input_data=input_data)
verify_model(Repeat2().float().eval(), input_data=input_data)
verify_model(Repeat3().float().eval(), input_data=input_data)

def test_forward_repeat_interleave():
torch.set_grad_enabled(False)
input_shape = [2, 2, 3]
class RepeatInterleave1(Module):
def forward(self, *args):
return args[0].repeat_interleave(2)

class RepeatInterleave2(Module):
def forward(self, *args):
return args[0].repeat_interleave(3, dim=0)

class RepeatInterleave3(Module):
def forward(self, *args):
return args[0].repeat_interleave(2, dim=1)

class RepeatInterleave4(Module):
def forward(self, *args):
return args[0].repeat_interleave(4, dim=2)

input_data = torch.rand(input_shape).float()
verify_model(RepeatInterleave1().float().eval(), input_data=input_data)
verify_model(RepeatInterleave2().float().eval(), input_data=input_data)
verify_model(RepeatInterleave3().float().eval(), input_data=input_data)
verify_model(RepeatInterleave4().float().eval(), input_data=input_data)

def test_forward_unsqueeze():
torch.set_grad_enabled(False)
input_shape = [10, 10]
Expand Down Expand Up @@ -600,6 +655,22 @@ def init_weight(m):
init_weight(ln.eval())
verify_model(ln.eval(), input_data=inp)

def test_forward_reshape():
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
new_shape = [2, 1, 10, 10]
class Reshape1(Module):
def forward(self, *args):
return args[0].reshape(new_shape)

class Reshape2(Module):
def forward(self, *args):
return args[0].reshape([-1])

input_data = torch.rand(input_shape).float()
verify_model(Reshape1().float().eval(), input_data=input_data)
verify_model(Reshape2().float().eval(), input_data=input_data)

def test_forward_transpose():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
Expand Down Expand Up @@ -1151,6 +1222,10 @@ def forward(self, xs):
test_forward_add()
test_forward_subtract()
test_forward_multiply()
test_forward_reshape()
test_forward_reciprocal()
test_forward_repeat()
test_forward_repeat_interleave()
test_forward_squeeze()
test_forward_unsqueeze()
test_forward_concatenate()
Expand Down

0 comments on commit b236565

Please sign in to comment.