Skip to content

Commit

Permalink
[RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support (apache#5395)
Browse files Browse the repository at this point in the history
* [RELAY][PYTORCH]cosh,sinh,log2,log10,log1p op support

* Review comment fixed

* Gradient testcase added
  • Loading branch information
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent 73dd012 commit f782a62
Show file tree
Hide file tree
Showing 14 changed files with 292 additions and 3 deletions.
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def _impl(inputs, input_types):
return _impl


def _log1p():
def _impl(inputs, input_types):
# 1_plus_log x = log(x + 1)
one = _expr.const(1, dtype="float32")
return _op.log(inputs[0] + one)
return _impl


def _arange():
def _impl(inputs, input_types):
if len(inputs) == 5:
Expand Down Expand Up @@ -1648,11 +1656,16 @@ def _get_convert_map(prelude):
"aten::abs" : _unary("abs"),
"aten::neg" : _unary("negative"),
"aten::cos" : _unary("cos"),
"aten::cosh" : _unary("cosh"),
"aten::sin" : _unary("sin"),
"aten::sinh" : _unary("sinh"),
"aten::tan" : _unary("tan"),
"aten::tanh" : _unary("tanh"),
"aten::atan" : _unary("atan"),
"aten::log" : _unary("log"),
"aten::log2" : _unary("log2"),
"aten::log10" : _unary("log10"),
"aten::log1p" : _log1p(),
"aten::exp" : _unary("exp"),
"aten::erf" : _unary("erf"),
"aten::trunc" : _unary("trunc"),
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@


register_broadcast_schedule("log")
register_broadcast_schedule("log2")
register_broadcast_schedule("log10")
register_broadcast_schedule("tan")
register_broadcast_schedule("cos")
register_broadcast_schedule("cosh")
register_broadcast_schedule("sin")
register_broadcast_schedule("sinh")
register_broadcast_schedule("atan")
register_broadcast_schedule("exp")
register_broadcast_schedule("erf")
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,14 @@
from .reduce import sum as _sum
from .tensor import (
cos,
cosh,
exp,
less,
negative,
ones_like,
power,
sin,
sinh,
zeros_like,
equal,
shape_of,
Expand Down Expand Up @@ -61,6 +63,24 @@ def log_grad(orig, grad):
return [grad * ones_like(x) / x]


@register_gradient("log2")
def log2_grad(orig, grad):
"""Returns [grad * 1 / (log(2) * x)]"""
x = orig.args[0]
ones = ones_like(x)
two = const(2.0)
return [grad * ones / (log(two) * x)]


@register_gradient("log10")
def log10_grad(orig, grad):
"""Returns [grad * 1 / (log(10) * x)]"""
x = orig.args[0]
ones = ones_like(x)
ten = const(10.0)
return [grad * ones / (log(ten) * x)]


@register_gradient("tan")
def tan_grad(orig, grad):
"""Returns [grad / (cos^2(x))]"""
Expand All @@ -76,12 +96,26 @@ def cos_grad(orig, grad):
return [grad * (-ones * sin(x))]


@register_gradient("cosh")
def cosh_grad(orig, grad):
"""Returns [grad * (-sinh(x))]"""
x = orig.args[0]
ones = ones_like(x)
return [grad * (-ones * sinh(x))]


@register_gradient("sin")
def sin_grad(orig, grad):
"""Returns [grad * cos(x)]"""
x = orig.args[0]
return [grad * cos(x)]

@register_gradient("sinh")
def sinh_grad(orig, grad):
"""Returns [grad * cosh(x)]"""
x = orig.args[0]
return [grad * cosh(x)]

@register_gradient("atan")
def atan_grad(orig, grad):
"""Returns [grad * 1 / (1 + x ^ 2)]"""
Expand Down
60 changes: 60 additions & 0 deletions python/tvm/relay/op/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,36 @@ def log(data):
"""
return _make.log(data)

def log2(data):
"""Compute elementwise log to the base 2 of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log2(data)

def log10(data):
"""Compute elementwise log to the base 10 of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log10(data)

def tan(data):
"""Compute elementwise tan of data.
Expand Down Expand Up @@ -77,6 +107,21 @@ def cos(data):
"""
return _make.cos(data)

def cosh(data):
"""Compute elementwise cosh of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.cosh(data)

def sin(data):
"""Compute elementwise sin of data.
Expand All @@ -92,6 +137,21 @@ def sin(data):
"""
return _make.sin(data)

def sinh(data):
"""Compute elementwise sinh of data.
Parameters
----------
data : relay.Expr
The input data
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sinh(data)

def atan(data):
"""Compute elementwise atan of data.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def where(condition, x, y):
Returns
-------
result : relay.Expr
The selected array.
The selected array.
Examples
--------
Expand Down
1 change: 1 addition & 0 deletions python/tvm/te/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# expose all operators in tvm tir.op
from tvm.tir import any, all, min_value, max_value, trace
from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
from tvm.tir import sinh, cosh, log2, log10
from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else
from tvm.tir import isnan, isfinite, isinf
from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def sin(x):


def sinh(x):
"""Take sin of input x.
"""Take sinh of input x.
Parameters
----------
Expand Down
44 changes: 44 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ RELAY_REGISTER_UNARY_OP("log")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));


RELAY_REGISTER_UNARY_OP("log2")
.describe(R"code(Returns the log to base 2 of input array, computed element-wise.
.. math::
log2(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log2));


RELAY_REGISTER_UNARY_OP("log10")
.describe(R"code(Returns the log to base 10 of input array, computed element-wise.
.. math::
log10(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log10));


RELAY_REGISTER_UNARY_OP("tan")
.describe(R"code(Returns the tan of input array, computed element-wise.
Expand All @@ -73,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("cos")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cos));


RELAY_REGISTER_UNARY_OP("cosh")
.describe(R"code(Returns the cosh of input array, computed element-wise.
.. math::
Y = cosh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::cosh));


RELAY_REGISTER_UNARY_OP("sin")
.describe(R"code(Returns the sin of input array, computed element-wise.
Expand All @@ -84,6 +117,17 @@ RELAY_REGISTER_UNARY_OP("sin")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sin));


RELAY_REGISTER_UNARY_OP("sinh")
.describe(R"code(Returns the sinh of input array, computed element-wise.
.. math::
Y = sinh(X)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::sinh));


RELAY_REGISTER_UNARY_OP("atan")
.describe(R"code(Returns the atan of input array, computed element-wise.
Expand Down
12 changes: 12 additions & 0 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.erf")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log2")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log10")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log1p")
.set_body(DispatchExtern<FloatSuffix>);

Expand All @@ -49,9 +55,15 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);

Expand Down
25 changes: 25 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,26 @@ class Neg1(Module):
def forward(self, *args):
return torch.neg(args[0])

class Sinh1(Module):
def forward(self, *args):
return torch.sinh(args[0])

class Cosh1(Module):
def forward(self, *args):
return torch.cosh(args[0])

class Log2_1(Module):
def forward(self, *args):
return torch.log2(args[0])

class Log10_1(Module):
def forward(self, *args):
return torch.log10(args[0])

class Log1p_1(Module):
def forward(self, *args):
return torch.log1p(args[0])

input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
verify_model(Sqrt1().float().eval(), input_data=input_data)
Expand All @@ -1887,11 +1907,16 @@ def forward(self, *args):
verify_model(Floor1().float().eval(), input_data=input_data)
verify_model(Round1().float().eval(), input_data=input_data)
verify_model(Cos1().float().eval(), input_data=input_data)
verify_model(Cosh1().float().eval(), input_data=input_data)
verify_model(Sin1().float().eval(), input_data=input_data)
verify_model(Sinh1().float().eval(), input_data=input_data)
verify_model(Tan1().float().eval(), input_data=input_data)
verify_model(Tanh1().float().eval(), input_data=input_data)
verify_model(ATanh1().float().eval(), input_data=input_data)
verify_model(Log1().float().eval(), input_data=input_data)
verify_model(Log2_1().float().eval(), input_data=input_data)
verify_model(Log10_1().float().eval(), input_data=input_data)
verify_model(Log1p_1().float().eval(), input_data=input_data)
verify_model(Exp1().float().eval(), input_data=input_data)
verify_model(Erf1().float().eval(), input_data=input_data)
verify_model(Trunc1().float().eval(), input_data=input_data)
Expand Down
6 changes: 5 additions & 1 deletion tests/python/relay/test_op_grad_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def check_single_op(opfunc, ref):
(tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
(tvm.relay.sin, lambda x: np.cos(x)),
(tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]:
(tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0))),
(tvm.relay.log2, lambda x: 1 / (np.log(2) * x)),
(tvm.relay.log10, lambda x: 1 / (np.log(10) * x)),
(tvm.relay.cosh, lambda x: -1.0 * np.sinh(x)),
(tvm.relay.sinh, lambda x: np.cosh(x))]:
check_single_op(opfunc, ref)


Expand Down
4 changes: 4 additions & 0 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,18 @@ TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
TOPI_DECLARE_UNARY_OP(log2);
TOPI_DECLARE_UNARY_OP(log10);
TOPI_DECLARE_UNARY_OP(floor);
TOPI_DECLARE_UNARY_OP(ceil);
TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);
TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(cosh);
TOPI_DECLARE_UNARY_OP(tan);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(sinh);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);
Expand Down
Loading

0 comments on commit f782a62

Please sign in to comment.