From 16288ad578b53e5d17b08efbb66cbd9b6da2be43 Mon Sep 17 00:00:00 2001
From: Alex Gladkov
Date: Mon, 17 Feb 2020 09:22:11 -0800
Subject: [PATCH] Fast exponent (#4790)

---
 topi/include/topi/elemwise.h        | 80 +++++++++++++++++++++++++++++
 topi/python/topi/math.py            | 16 ++++++
 topi/src/topi.cc                    |  5 ++
 topi/tests/python/test_topi_math.py | 38 ++++++++++++++
 4 files changed, 139 insertions(+)

diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index e3f4678c1163..e35e3e424d6e 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -377,5 +377,85 @@ inline Tensor full_like(const Tensor& x,
   }, name, tag);
 }
+/*!
+ * \brief Fast exponential function implementation
+ *
+ * \param _x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the exponent operation
+ *
+ * \note The function computes e^x = 2^(x * log2(e)), since
+ *       log2(e^x) = x * log2(e).
+ *       The power x * log2(e) is split into an integer part n
+ *       and a remainder f:
+ *       n = floor(x * log2(e) + 1/2)
+ *       f = x - n * ln(2)
+ *       so that exp(x) = 2^n * exp(f)
+ *       The factor exp(f) is approximated by the polynomial
+ *       y = exp(f) ~ 1 + f + f^2 * P(f), where P has degree 5
+ */
+inline Tensor fast_exp_float32(const Tensor& _x,
+                               std::string name,
+                               std::string tag) {
+  auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
+  auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
+  auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
+  auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
+  PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
+                   make_const(DataType::Float(32), 1.3981999507E-3f),
+                   make_const(DataType::Float(32), 8.3334519073E-3f),
+                   make_const(DataType::Float(32), 4.1665795894E-2f),
+                   make_const(DataType::Float(32), 1.6666665459E-1f),
+                   make_const(DataType::Float(32), 5.0000001201E-1f)};
+  auto one = make_const(DataType::Float(32), 1.0f);
+  auto one_half = make_const(DataType::Float(32), 0.5f);
+  auto b = make_const(DataType::Float(32), 127.0f);
+
+  return compute(_x->shape,
+                 [&](const Array<Var>& i) {
+                   // clamp x to the range where float32 exp neither
+                   // overflows nor underflows
+                   auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
+                   // integer part
+                   auto n = ::tvm::floor(x * log2e + one_half);
+                   // fractional part
+                   auto f = x - n * ln2;
+                   auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f
+                             + p[4]) * f + p[5]) * f * f + f + one;
+                   // Return 2^n * exp(f); 2^n is built by placing the
+                   // biased exponent n + 127 into the float32 exponent bits.
+                   auto ef = tvm::reinterpret(DataType::Float(32),
+                                              ::tvm::cast(DataType::Int(32), n + b) << 23);
+                   return ::tvm::max(ef * y, _x(i));  // NOLINT(*)
+                 },
+                 name, tag);
+}
+
+
+/*!
+ * \brief Fast exponential function implementation
+ *
+ * \param x The input tensor
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the exponent operation
+ *
+ * \note Dispatches to the fast float32 implementation above; other
+ *       dtypes fall back to the exact ::tvm::exp.
+ */
+inline Tensor fast_exp(const Tensor& x,
+                       std::string name = "T_fast_exp",
+                       std::string tag = kElementWise) {
+  if (x->dtype == DataType::Float(32)) {
+    auto ret = fast_exp_float32(x, name, tag);
+    return ret;
+  } else {
+    return compute(x->shape, [&](const Array<Var>& i) {
+      return ::tvm::exp(x(i));
+    }, name, tag);
+  }
+}
 
 }  // namespace topi
 #endif  // TOPI_ELEMWISE_H_
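The \note above is easiest to check against a concrete model. The following NumPy sketch is not part of the patch: fast_exp_ref is an illustrative name, and the constants are copied from the C++ hunk above. It performs the same range reduction, Horner polynomial, and exponent-bit reconstruction.

import numpy as np

def fast_exp_ref(x):
    # Reference model (not from the patch) of fast_exp_float32.
    # Clamp to the same range the TVM kernel uses.
    x = np.float32(np.clip(x, -88.3762626647949, 88.3762626647950))
    log2e = np.float32(1.44269504088896341)
    ln2 = np.float32(0.6931471805599453)
    # integer part: n = floor(x * log2(e) + 1/2)
    n = np.floor(x * log2e + np.float32(0.5))
    # fractional part: f = x - n * ln(2)
    f = x - n * ln2
    # degree-5 polynomial P(f); exp(f) ~ 1 + f + f^2 * P(f)
    p = np.array([1.9875691500E-4, 1.3981999507E-3, 8.3334519073E-3,
                  4.1665795894E-2, 1.6666665459E-1, 5.0000001201E-1],
                 dtype=np.float32)
    y = np.polyval(p, f) * f * f + f + np.float32(1.0)
    # 2^n via the float32 bit layout: biased exponent (n + 127) << 23
    two_n = ((n.astype(np.int32) + np.int32(127)) << 23).view(np.float32)
    return two_n * y

print(fast_exp_ref(1.5), np.exp(np.float32(1.5)))  # agree to ~1e-7 relative

The last step is the standard bit trick: for an integer n, the float32 word whose exponent field holds n + 127 is exactly 2^n, so computing 2^n * exp(f) needs no transcendental call.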
""" return cpp.reinterpret(x, dtype) + + +def fast_exp(x): + """Take exponential of input x using fast_exp implementation + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. + """ + return cpp.fast_exp(x, x.dtype, tag.ELEMWISE) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 2b2142bb5759..a7b916093d98 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("topi.exp") *rv = exp(args[0]); }); +TVM_REGISTER_GLOBAL("topi.fast_exp") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = fast_exp(args[0]); + }); + TVM_REGISTER_GLOBAL("topi.erf") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = erf(args[0]); diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py index bb674364ff2e..5bb95ba10e3b 100644 --- a/topi/tests/python/test_topi_math.py +++ b/topi/tests/python/test_topi_math.py @@ -185,7 +185,45 @@ def verify(from_dtype, to_dtype, low=-100, high=100): verify("bool", "int32") +def test_fastmath(): + def test_apply( + func, + name, + f_numpy, + low, + high, + step, + dtype=tvm.float32 + ): + a_np = np.arange(low, high, step).astype(dtype) + b_np = f_numpy(a_np) + A = tvm.placeholder(a_np.shape, dtype=dtype, name="A") + B = func(A) + assert tuple(B.shape) == tuple(A.shape) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + func = tvm.build(s, [A, B], device, name=name) + a = tvm.nd.array(a_np, ctx) + b = tvm.nd.array(np.zeros_like(b_np), ctx) + func(a, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5) + + check_device('llvm') + check_device('llvm -device=arm-cpu') + + + test_apply(topi.fast_exp, "fast_exp", np.exp, + low=-88, high=88, + step = 0.01) + if __name__ == "__main__": test_util() test_ewise() test_cast() + test_fastmath()