[PYTORCH] Abs, Arange, Softplus ops (apache#5295)
* [PYTORCH] Abs, Arange, Softplus ops

* Review comments updated
siju-samuel authored and Trevor Morris committed Apr 16, 2020
1 parent 2be18c4 commit 715eb3e
Showing 2 changed files with 118 additions and 0 deletions.
52 changes: 52 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -57,6 +57,33 @@ def _impl(inputs, input_types):
        return get_relay_op(name)(data0, data1)
    return _impl

def _abs():
    def _impl(inputs, input_types):
        data = inputs[0]
        return _op.abs(data)
    return _impl

def _arange():
    def _impl(inputs, input_types):
        if len(inputs) == 5:
            # torch.arange(end) form: [end, dtype, layout, device, pin_memory]
            dtype = "float" if "float" in input_types[0:1] else _convert_dtype_value(inputs[1])
            start = _create_typed_const(0, dtype)
            stop = _create_typed_const(inputs[0], dtype)
            step = _create_typed_const(1, dtype)
        elif len(inputs) == 7:
            # torch.arange(start, end, step) form: [start, end, step, dtype, layout, device, pin_memory]
            dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3])
            start = _create_typed_const(inputs[0], dtype)
            stop = _create_typed_const(inputs[1], dtype)
            step = _create_typed_const(inputs[2], dtype)
        else:
            msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
            raise AssertionError(msg)
        return _op.transform.arange(start=start,
                                    stop=stop,
                                    step=step,
                                    dtype=_convert_data_type(dtype))
    return _impl
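
Note: in the TorchScript graph, aten::arange arrives with either five inputs (the torch.arange(end) form) or seven inputs (the torch.arange(start, end, step) form); the trailing inputs carry dtype, layout, device, and pin_memory. A minimal standalone sketch of that unpacking, using hypothetical input lists in place of what the JIT graph would actually supply:

# Sketch only: mirrors _arange's argument unpacking with plain Python values.
def unpack_arange(inputs):
    if len(inputs) == 5:    # torch.arange(end)
        return 0, inputs[0], 1
    elif len(inputs) == 7:  # torch.arange(start, end, step)
        return inputs[0], inputs[1], inputs[2]
    raise AssertionError("Unknown number of arguments (%d) to parse." % len(inputs))

print(unpack_arange([5, None, None, None, False]))        # (0, 5, 1)
print(unpack_arange([1, 6, 2, None, None, None, False]))  # (1, 6, 2)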

def _squeeze():
    def _impl(inputs, input_types):
        data = inputs[0]
@@ -732,6 +759,13 @@ def _impl(inputs, input_types):
        return _op.tensor.sigmoid(data)
    return _impl

def _softplus():
    def _impl(inputs, input_types):
        data = inputs[0]
        beta = _expr.const(float(inputs[1]))
        return _op.log(_op.exp(data * beta) + _expr.const(1.)) / beta
    return _impl
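
Note: this is the closed form softplus(x) = log(1 + exp(beta * x)) / beta. nn.Softplus also carries a threshold argument (inputs[2] here, unused) above which PyTorch returns the input unchanged for numerical stability; ignoring it is harmless for moderate inputs. A quick numerical check of the formula against PyTorch, assuming NumPy and torch are on hand:

# Sketch only: verify the closed form used above against torch.nn.Softplus.
import numpy as np
import torch

x = np.random.uniform(-3, 3, size=(4,)).astype("float32")
beta = 1.5
expected = torch.nn.Softplus(beta=beta)(torch.from_numpy(x)).numpy()
actual = np.log(np.exp(x * beta) + 1.0) / beta
np.testing.assert_allclose(actual, expected, rtol=1e-5)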

def _avg_pool2d():
    def _impl(inputs, input_types):
        data = inputs[0]
@@ -1044,6 +1078,21 @@ def _impl(inputs, input_types):
    return _impl

# Helper functions for operator implementation
def _convert_dtype_value(val):
    convert_torch_dtype_map = {7:"torch.float64",
                               6:"torch.float32",
                               5:"torch.float16",
                               4:"torch.int64",
                               3:"torch.int32",
                               2:"torch.int16",
                               1:"torch.int8",
                               0:"torch.uint8",
                               None:"torch.int64"} # Default is torch.int64
    if val in convert_torch_dtype_map:
        return convert_torch_dtype_map[val]
    else:
        msg = "Torch data type value %d is not handled yet." % (val)
        raise NotImplementedError(msg)
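
Note: the integer keys are PyTorch's at::ScalarType enum codes (0 = uint8 through 7 = float64), which is how constant dtype arguments are encoded in the traced graph; None falls back to torch.int64, PyTorch's default integer dtype. A small usage sketch:

# Sketch only: resolving ScalarType codes the way the converters above do.
assert _convert_dtype_value(6) == "torch.float32"
assert _convert_dtype_value(4) == "torch.int64"
assert _convert_dtype_value(None) == "torch.int64"  # default when no dtype is given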

def _convert_data_type(input_type):
    if input_type in ["double", "torch.float64"]:
@@ -1118,6 +1167,8 @@ def _wrap_const(c):
"aten::pow" : _elemwise("power"),
"aten::div" : _elemwise("divide"),
"aten::div_" : _elemwise("divide"),
"aten::abs" : _abs(),
"aten::arange" : _arange(),
"aten::ones" : _ones(),
"aten::zeros" : _zeros(),
"aten::reciprocal" : _reciprocal(),
@@ -1167,6 +1218,7 @@ def _wrap_const(c):
"aten::clone" : _clone(),
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
"aten::softplus" : _softplus(),
"aten::avg_pool2d" : _avg_pool2d(),
"aten::avg_pool3d" : _avg_pool3d(),
"aten::dropout" : _dropout(),
66 changes: 66 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -375,6 +375,54 @@ def forward(self, *args):
    verify_model(Squeeze1().float().eval(), input_data=input_data)
    verify_model(Squeeze2().float().eval(), input_data=input_data)

def test_forward_arange():
    torch.set_grad_enabled(False)

    class Arange1(Module):
        def forward(self, *args):
            return torch.arange(5)
    class Arange2(Module):
        def forward(self, *args):
            return torch.arange(2.5)
    class Arange3(Module):
        def forward(self, *args):
            return torch.arange(1, 4)
    class Arange4(Module):
        def forward(self, *args):
            return torch.arange(1, 2.5, 0.5)
    class Arange5(Module):
        def forward(self, *args):
            return torch.arange(1, 2, 1, dtype=torch.int32)
    class Arange6(Module):
        def forward(self, *args):
            return torch.arange(start=1, end=6, step=2)
    class Arange7(Module):
        def forward(self, *args):
            return torch.arange(1, 4, dtype=torch.float32)
    class Arange8(Module):
        def forward(self, *args):
            return torch.arange(1, 2, 1, dtype=torch.int16)

    verify_model(Arange1().float().eval())
    verify_model(Arange2().float().eval())
    verify_model(Arange3().float().eval())
    verify_model(Arange4().float().eval())
    verify_model(Arange5().float().eval())
    verify_model(Arange6().float().eval())
    verify_model(Arange7().float().eval())
    verify_model(Arange8().float().eval())

def test_forward_abs():
    torch.set_grad_enabled(False)
    input_shape = [2, 1, 10, 1, 10]

    class Abs1(Module):
        def forward(self, *args):
            return args[0].abs()

    input_data = torch.rand(input_shape).float()
    verify_model(Abs1().float().eval(), input_data=input_data)

def test_forward_concatenate():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
@@ -445,6 +493,20 @@ def test_forward_selu():
    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.SELU().eval(), input_data=input_data)

def test_forward_softplus():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.Softplus().eval(), input_data=input_data)
    verify_model(torch.nn.Softplus(beta=1.5, threshold=20).eval(), input_data=input_data)
    verify_model(torch.nn.Softplus(beta=5, threshold=10).eval(), input_data=input_data)

def test_forward_softsign():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
    input_data = torch.rand(input_shape).float()
    verify_model(torch.nn.Softsign().eval(), input_data=input_data)

def test_forward_log_sigmoid():
    torch.set_grad_enabled(False)
    input_shape = [10, 10]
@@ -1254,6 +1316,8 @@ def forward(self, xs):
    test_forward_view()
    test_forward_select()
    test_forward_clone()
    test_forward_softplus()
    test_forward_softsign()
    test_forward_logsoftmax()
    test_forward_sigmoid()
    test_forward_dense()
@@ -1264,6 +1328,8 @@ def forward(self, xs):
    test_forward_mean()
    test_forward_expand()
    test_forward_pow()
    test_forward_abs()
    test_forward_arange()
    test_forward_chunk()
    test_forward_split()
    test_upsample()