diff --git a/docs/api/python/intrin.rst b/docs/api/python/intrin.rst index da8b64243209..60141d020c9e 100644 --- a/docs/api/python/intrin.rst +++ b/docs/api/python/intrin.rst @@ -35,6 +35,7 @@ tvm.intrin tvm.ceil tvm.trunc tvm.round + tvm.nearbyint tvm.abs tvm.isnan @@ -52,5 +53,6 @@ tvm.intrin .. autofunction:: tvm.ceil .. autofunction:: tvm.trunc .. autofunction:: tvm.round +.. autofunction:: tvm.nearbyint .. autofunction:: tvm.abs .. autofunction:: tvm.isnan diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index b0e82e7fb50c..7ed1c47c8387 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -528,6 +528,14 @@ TVM_DLL Expr ceil(Expr x); */ TVM_DLL Expr round(Expr x); +/*! + * \brief Calculates std::nearbyint(x) + * \param x The input expression. + * \return The result expression. + * This is a faster alternate to round. + */ +TVM_DLL Expr nearbyint(Expr x); + /*! * \brief Calculate trunc(x) * \param x The input expression. diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 2a4ebfec135b..6a580d39486c 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -434,6 +434,29 @@ def round(x): return _make.round(x) +def nearbyint(x): + """Round elements of the array to the nearest integer. + This intrinsic uses llvm.nearbyint instead of llvm.round + which is faster but will results different from tvm.round. + Notably nearbyint rounds according to the rounding mode, + whereas tvm.round (llvm.round) ignores that. + For differences between the two see: + https://en.cppreference.com/w/cpp/numeric/math/round + https://en.cppreference.com/w/cpp/numeric/math/nearbyint + + Parameters + ---------- + x : Expr + Input argument. + + Returns + ------- + y : Expr + The result. + """ + return _make.nearbyint(x) + + def isnan(x): """Check if input value is Nan. diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index b1f9af4f6f75..7c70a6de63b3 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -50,6 +50,9 @@ TVM_REGISTER_API("make.ceil") TVM_REGISTER_API("make.round") .set_body_typed(tvm::round); +TVM_REGISTER_API("make.nearbyint") +.set_body_typed(tvm::nearbyint); + TVM_REGISTER_API("make.trunc") .set_body_typed(tvm::trunc); diff --git a/src/codegen/llvm/intrin_rule_llvm.cc b/src/codegen/llvm/intrin_rule_llvm.cc index f8824083fe5d..e324fa0dcf3b 100644 --- a/src/codegen/llvm/intrin_rule_llvm.cc +++ b/src/codegen/llvm/intrin_rule_llvm.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -59,6 +59,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs") TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round") .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); +TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint") +.set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); + TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh") .set_body([](const TVMArgs& targs, TVMRetValue* rv) { Expr e = targs[0]; diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index d7a40c133784..d1e387548832 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -521,6 +521,13 @@ Expr round(Expr x) { return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic); } +Expr nearbyint(Expr x) { + using ir::FloatImm; + const FloatImm* fx = x.as(); + if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value)); + return ir::Call::make(x.type(), "nearbyint", {x}, ir::Call::PureIntrinsic); +} + Expr trunc(Expr x) { using ir::FloatImm; const FloatImm* fx = x.as(); diff --git a/tests/python/unittest/test_tvm_intrin.py b/tests/python/unittest/test_tvm_intrin.py new file mode 100644 index 000000000000..23e921d3f1ce --- /dev/null +++ b/tests/python/unittest/test_tvm_intrin.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import topi +from tvm.contrib import util, clang +import numpy as np +import ctypes +import math + + +def test_nearbyint(): + m = tvm.var("m",) + A = tvm.placeholder((m,), name='A') + A_rounded = tvm.compute((m,), lambda *i: tvm.nearbyint(A(*i)), name='A') + s = tvm.create_schedule(A_rounded.op) + f = tvm.build(s, [A, A_rounded], "llvm") + ctx = tvm.cpu(0) + n = 10 + a = tvm.nd.array(np.random.uniform(high=100, size=n).astype(A.dtype), ctx) + a_rounded = tvm.nd.array( \ + np.random.uniform(size=n).astype(A_rounded.dtype), ctx) + f(a, a_rounded) + # Note that numpys rint rounds to nearest integer with + # ties to halfway is broken by rounding to even. + # So that 1.5 and 2.5 will round 2. + # This is the default rounding mode with libc as well. + # However one can set a different rounding mode and in that + # case numpy result might differ. + tvm.testing.assert_allclose( + a_rounded.asnumpy(), np.rint(a.asnumpy())) + + +if __name__ == "__main__": + test_nearbyint()