Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[intrin] a few more math functions #5468

Merged
merged 1 commit into from
Apr 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,13 @@ TVM_DECLARE_INTRIN_UNARY(cos);
TVM_DECLARE_INTRIN_UNARY(cosh);
TVM_DECLARE_INTRIN_UNARY(sin);
TVM_DECLARE_INTRIN_UNARY(sinh);
TVM_DECLARE_INTRIN_UNARY(asin);
TVM_DECLARE_INTRIN_UNARY(acos);
TVM_DECLARE_INTRIN_UNARY(atan);
TVM_DECLARE_INTRIN_UNARY(acosh);
TVM_DECLARE_INTRIN_UNARY(asinh);
TVM_DECLARE_INTRIN_UNARY(atanh);


namespace tir {
/*!
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, all, any, min_value, max_value, trace
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp
from .op import cos, sin, cosh, sinh, tan, tanh, atan, atan2
from .op import sin, sinh, asin, asinh
from .op import cos, cosh, acos, acosh
from .op import tan, tanh, atan, atan2, atanh
from .op import erf, sigmoid, sqrt, rsqrt, floor, ceil, hypot
from .op import trunc, abs, round, nextafter, nearbyint, power, popcount, fmod, if_then_else
from .op import isnan, isfinite, isinf, copysign
Expand Down
80 changes: 80 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,38 @@ def cosh(x):
return call_pure_intrin(x.dtype, "cosh", x)


def acos(x):
"""Take acos of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "acos", x)


def acosh(x):
"""Take acos of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "acosh", x)


def sin(x):
"""Take sin of input x.

Expand Down Expand Up @@ -554,6 +586,38 @@ def sinh(x):
return call_pure_intrin(x.dtype, "sinh", x)


def asin(x):
"""Take asin of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "asin", x)


def asinh(x):
"""Take asinh of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "asinh", x)


def atan(x):
"""Take atan of input x.

Expand All @@ -570,6 +634,22 @@ def atan(x):
return call_pure_intrin(x.dtype, "atan", x)


def atanh(x):
"""Take atanh of input x.

Parameters
----------
x : PrimExpr
Input argument.

Returns
-------
y : PrimExpr
The result.
"""
return call_pure_intrin(x.dtype, "atanh", x)


def atan2(x1, x2):
"""Take arctan2(x1, x2).

Expand Down
21 changes: 18 additions & 3 deletions src/target/intrin_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,37 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atanh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acos")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cosh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.acosh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sin")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asin")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.sinh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.atan2")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.asinh")
.set_body(DispatchExtern<FloatSuffix>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.hypot")
Expand Down
8 changes: 7 additions & 1 deletion tests/python/unittest/test_tir_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def test_unary_intrin():
(tvm.tir.sinh, lambda x : np.sinh(x)),
(tvm.tir.cosh, lambda x : np.cosh(x)),
(tvm.tir.log1p, lambda x : np.log1p(x)),
(tvm.tir.asin, lambda x : np.arcsin(x)),
(tvm.tir.acos, lambda x : np.arccos(x)),
(tvm.tir.atan, lambda x : np.arctan(x)),
(tvm.tir.asinh, lambda x : np.arcsinh(x)),
(tvm.tir.acosh, lambda x : np.arccosh(x)),
(tvm.tir.atanh, lambda x : np.arctanh(x)),
]
def run_test(tvm_intrin, np_func):
m = te.var("m",)
Expand All @@ -72,7 +78,7 @@ def run_test(tvm_intrin, np_func):
f = tvm.build(s, [A, B], "llvm")
ctx = tvm.cpu(0)
n = 10
a = tvm.nd.array(np.random.uniform(0, 1, size=n).astype(A.dtype), ctx)
a = tvm.nd.array(np.random.uniform(0.1, 0.5, size=n).astype(A.dtype), ctx)
b = tvm.nd.array( \
np.random.uniform(size=n).astype(A.dtype), ctx)
f(a, b)
Expand Down