
Fast exponent
alexgl-github committed Jan 29, 2020
1 parent 1b8522e commit 12358c4
Showing 1 changed file with 67 additions and 1 deletion.
68 changes: 67 additions & 1 deletion topi/include/topi/elemwise.h
@@ -44,7 +44,6 @@ using namespace tvm::te;
}, name, tag); \
}

TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(erf);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
@@ -360,5 +359,72 @@ inline Tensor full_like(const Tensor& x,
}, name, tag);
}

/*!
 * \brief Fast exponential function implementation from Eigen
 * https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h#L183
 * Works by writing "x = m*log(2) + r" where "m = floor(x/log(2)+1/2)" and
 * "r" is the remainder, so that "r" lies in [-log(2)/2, log(2)/2). The result
 * is then "exp(x) = 2^m*exp(r)", with exp(r) approximated by a Cephes-style
 * polynomial.
 */
inline Tensor fast_exp_float(const Tensor& _x,
std::string name,
std::string tag) {
auto cst_1 = make_const(DataType::Float(32), 1.0f);
auto cst_half = make_const(DataType::Float(32), 0.5f);
auto cst_exp_hi = make_const(DataType::Float(32), 88.3762626647950f);
auto cst_exp_lo = make_const(DataType::Float(32), -88.3762626647949f);
auto cst_cephes_LOG2EF = make_const(DataType::Float(32), 1.44269504088896341f);
auto cst_cephes_exp_p0 = make_const(DataType::Float(32), 1.9875691500E-4f);
auto cst_cephes_exp_p1 = make_const(DataType::Float(32), 1.3981999507E-3f);
auto cst_cephes_exp_p2 = make_const(DataType::Float(32), 8.3334519073E-3f);
auto cst_cephes_exp_p3 = make_const(DataType::Float(32), 4.1665795894E-2f);
auto cst_cephes_exp_p4 = make_const(DataType::Float(32), 1.6666665459E-1f);
auto cst_cephes_exp_p5 = make_const(DataType::Float(32), 5.0000001201E-1f);
auto cst_nln2 = make_const(DataType::Float(32), -0.6931471805599453f);
auto cst_127 = make_const(DataType::Float(32), 127.0f);

return compute(_x->shape,
[&](const Array<Var>& i) {
// Clamp x to the range where the float32 result stays finite.
auto x = ::tvm::max(::tvm::min(_x(i), cst_exp_hi), cst_exp_lo);

// Express exp(x) as exp(m*ln(2) + r), start by extracting
// m = floor(x/ln(2) + 0.5).
auto m = ::tvm::floor(x * cst_cephes_LOG2EF + cst_half);

// Get r = x - m*ln(2).
auto r = m * cst_nln2 + x;
auto r2 = r * r;

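// Evaluate the Cephes polynomial approximation of exp(r) in Horner form.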
auto y = cst_cephes_exp_p0;
y = y * r + cst_cephes_exp_p1;
y = y * r + cst_cephes_exp_p2;
y = y * r + cst_cephes_exp_p3;
y = y * r + cst_cephes_exp_p4;
y = y * r + cst_cephes_exp_p5;
y = y * r2 + r;
y = y + cst_1;

// Return 2^m * exp(r).
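// Shifting (m + 127) into the IEEE-754 exponent bits and reinterpreting yields the float 2^m.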
auto ei = ::tvm::cast(DataType::Int(32), m + cst_127) << 23;
auto ef = ::tvm::reinterpret(DataType::Float(32), ei);
return ::tvm::max(ef * y, _x(i)); // NOLINT(*)
},
name, tag);
}

inline Tensor exp(const Tensor& x,
std::string name = "T_exp",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
return fast_exp_float(x, name, tag);
} else {
return compute(x->shape, [&](const Array<Var>& i) {
return ::tvm::exp(x(i));
}, name, tag);
}
}

} // namespace topi
#endif // TOPI_ELEMWISE_H_
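
For reference, the scheme described in the comment above (clamp, range reduction m = floor(x/ln(2) + 0.5), Cephes-style polynomial for exp(r), and the exponent bit trick for 2^m) can be checked in isolation against std::exp. Below is a minimal scalar C++ sketch using the same constants as the TOPI kernel; the helper name fast_exp_scalar and the small driver are illustrative only and are not part of this commit.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float fast_exp_scalar(float x0) {
  // Clamp to the same range the TOPI kernel uses.
  float x = std::min(std::max(x0, -88.3762626647949f), 88.3762626647950f);
  // m = floor(x/ln(2) + 0.5), so x = m*ln(2) + r with r in [-ln(2)/2, ln(2)/2).
  float m = std::floor(x * 1.44269504088896341f + 0.5f);
  float r = m * -0.6931471805599453f + x;
  float r2 = r * r;
  // Cephes polynomial approximation of exp(r), Horner form.
  float y = 1.9875691500E-4f;
  y = y * r + 1.3981999507E-3f;
  y = y * r + 8.3334519073E-3f;
  y = y * r + 4.1665795894E-2f;
  y = y * r + 1.6666665459E-1f;
  y = y * r + 5.0000001201E-1f;
  y = y * r2 + r + 1.0f;
  // 2^m via the exponent bits: (m + 127) << 23 reinterpreted as a float.
  int32_t ei = (static_cast<int32_t>(m) + 127) << 23;
  float ef;
  std::memcpy(&ef, &ei, sizeof(ef));
  // Mirror the final max() in the TOPI kernel, which uses the unclamped input.
  return std::max(ef * y, x0);
}

int main() {
  const float xs[] = {-10.0f, -1.0f, 0.0f, 0.5f, 3.0f, 20.0f};
  for (float x : xs) {
    std::printf("x=%6.2f  fast=%.7g  exp=%.7g\n", x, fast_exp_scalar(x), std::exp(x));
  }
  return 0;
}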
