Fast exponent (apache#4790)
alexgl-github authored and alexwong committed Feb 28, 2020
1 parent 719a91a commit bca1935
Showing 4 changed files with 139 additions and 0 deletions.
80 changes: 80 additions & 0 deletions topi/include/topi/elemwise.h
@@ -377,5 +377,85 @@ inline Tensor full_like(const Tensor& x,
}, name, tag);
}

/*!
 * \brief Fast exponential function implementation
 *
 * \param _x The input tensor
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the exponent operation
 *
 * \note The function computes:
 *       log2(e^x) = x * log2(e), hence e^x = 2^(x * log2(e)).
 *       Split the exponent x * log2(e) into integer and fractional parts:
 *       n = floor(x * log2(e) + 1/2)
 *       f = x - n * ln(2)
 *       so that e^x = 2^n * e^f.
 *       The fractional part is approximated by a degree-5 polynomial:
 *       e^f ~= 1 + f + f^2 * P(f)
 *       For example, x = 1 gives n = 1, f = 1 - ln(2) ~= 0.3069, and
 *       2^1 * e^0.3069 ~= 2.7183 ~= e.
 */
inline Tensor fast_exp_float32(const Tensor& _x,
                               std::string name,
                               std::string tag) {
  auto x_hi = make_const(DataType::Float(32), 88.3762626647950f);
  auto x_lo = make_const(DataType::Float(32), -88.3762626647949f);
  auto log2e = make_const(DataType::Float(32), 1.44269504088896341f);
  auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f);
  PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f),
                   make_const(DataType::Float(32), 1.3981999507E-3f),
                   make_const(DataType::Float(32), 8.3334519073E-3f),
                   make_const(DataType::Float(32), 4.1665795894E-2f),
                   make_const(DataType::Float(32), 1.6666665459E-1f),
                   make_const(DataType::Float(32), 5.0000001201E-1f)};
  auto one = make_const(DataType::Float(32), 1.0f);
  auto one_half = make_const(DataType::Float(32), 0.5f);
  auto b = make_const(DataType::Float(32), 127.0f);

  return compute(_x->shape,
                 [&](const Array<Var>& i) {
                   // clamp x to the range where float32 exp is representable
                   auto x = ::tvm::max(::tvm::min(_x(i), x_hi), x_lo);
                   // integer part
                   auto n = ::tvm::floor(x * log2e + one_half);
                   // fractional part
                   auto f = x - n * ln2;
                   // degree-5 polynomial approximation of exp(f)
                   auto y = (((((p[0] * f + p[1]) * f + p[2]) * f + p[3]) * f +
                              p[4]) * f + p[5]) * f * f + f + one;
                   // Return 2^n * exp(f): build 2^n by writing (n + 127)
                   // into the float32 exponent bits.
                   auto ef = tvm::reinterpret(DataType::Float(32),
                                              ::tvm::cast(DataType::Int(32), n + b) << 23);
                   return ::tvm::max(ef * y, _x(i));  // NOLINT(*)
                 },
                 name, tag);
}
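To make the range reduction and the bit-level construction of 2^n concrete, here is a minimal NumPy sketch of the same scheme. It is an illustration only, not part of the commit; fast_exp_np is a hypothetical name. It reuses the constants above but omits the final max against the input that the kernel applies:

import numpy as np

def fast_exp_np(x):
    # clamp to the range where float32 exp neither overflows nor underflows
    x = np.clip(x, -88.3762626647949, 88.3762626647950).astype(np.float32)
    log2e = np.float32(1.44269504088896341)
    ln2 = np.float32(0.6931471805599453)
    n = np.floor(x * log2e + np.float32(0.5))   # integer part
    f = x - n * ln2                             # fractional part
    # degree-5 polynomial P(f), same coefficients as the kernel above
    p = np.array([1.9875691500e-4, 1.3981999507e-3, 8.3334519073e-3,
                  4.1665795894e-2, 1.6666665459e-1, 5.0000001201e-1],
                 dtype=np.float32)
    y = np.float32(1.0) + f + f * f * np.polyval(p, f)
    # 2^n built by writing (n + 127) into the float32 exponent bits
    two_n = ((n.astype(np.int32) + 127) << 23).view(np.float32)
    return two_n * y

x = np.linspace(-10.0, 10.0, 2001).astype(np.float32)
rel_err = np.abs(fast_exp_np(x) - np.exp(x)) / np.exp(x)
print(rel_err.max())  # on the order of 1e-7 for float32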


/*!
 * \brief Fast exponential function implementation
 *
 * \param x The input tensor
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return A Tensor whose op member is the exponent operation
 */
inline Tensor fast_exp(const Tensor& x,
                       std::string name = "T_fast_exp",
                       std::string tag = kElementWise) {
  if (x->dtype == DataType::Float(32)) {
    return fast_exp_float32(x, name, tag);
  } else {
    return compute(x->shape, [&](const Array<Var>& i) {
      return ::tvm::exp(x(i));
    }, name, tag);
  }
}
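Note the dispatch: only float32 inputs take the fast path; every other dtype falls back to the exact tvm::exp. A small hedged illustration at the Python level (placeholder shapes and names are arbitrary):

import tvm
import topi

A32 = tvm.placeholder((16,), dtype="float32", name="A32")
B32 = topi.fast_exp(A32)  # takes the fast_exp_float32 path above

A64 = tvm.placeholder((16,), dtype="float64", name="A64")
B64 = topi.fast_exp(A64)  # falls back to the exact tvm::exp compute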

} // namespace topi
#endif // TOPI_ELEMWISE_H_
16 changes: 16 additions & 0 deletions topi/python/topi/math.py
@@ -451,3 +451,19 @@ def reinterpret(x, dtype):
        The result.
    """
    return cpp.reinterpret(x, dtype)


def fast_exp(x):
    """Take the exponential of input x using the fast_exp implementation.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
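A minimal end-to-end usage sketch, mirroring the test added below (it assumes the pre-0.7 TVM API that this commit targets):

import numpy as np
import tvm
import topi

A = tvm.placeholder((1024,), name="A", dtype="float32")
B = topi.fast_exp(A)
with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective(B)
f = tvm.build(s, [A, B], "llvm", name="fast_exp")

ctx = tvm.context("llvm", 0)
a = tvm.nd.array(np.random.uniform(-10, 10, size=1024).astype("float32"), ctx)
b = tvm.nd.array(np.zeros(1024, dtype="float32"), ctx)
f(a, b)
tvm.testing.assert_allclose(b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5, atol=1e-5)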
5 changes: 5 additions & 0 deletions topi/src/topi.cc
@@ -165,6 +165,11 @@ TVM_REGISTER_GLOBAL("topi.exp")
  *rv = exp(args[0]);
  });

TVM_REGISTER_GLOBAL("topi.fast_exp")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_exp(args[0]);
});
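This registration is what cpp.fast_exp in topi/python/topi/math.py ultimately resolves to: the Python wrapper looks the name up in TVM's global packed-function table. A hedged sketch of the equivalent direct lookup, assuming this symbol behaves like other topi globals:

import tvm

f = tvm.get_global_func("topi.fast_exp")
# Calling f on a tvm Tensor returns the output Tensor,
# the same object topi.fast_exp(x) would hand back.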

TVM_REGISTER_GLOBAL("topi.erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = erf(args[0]);
38 changes: 38 additions & 0 deletions topi/tests/python/test_topi_math.py
@@ -185,7 +185,45 @@ def verify(from_dtype, to_dtype, low=-100, high=100):
    verify("bool", "int32")


def test_fastmath():
    def test_apply(func, name, f_numpy, low, high, step, dtype=tvm.float32):
        a_np = np.arange(low, high, step).astype(dtype)
        b_np = f_numpy(a_np)
        A = tvm.placeholder(a_np.shape, dtype=dtype, name="A")
        B = func(A)
        assert tuple(B.shape) == tuple(A.shape)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            with tvm.target.create(device):
                s = topi.generic.schedule_injective(B)
            func = tvm.build(s, [A, B], device, name=name)
            a = tvm.nd.array(a_np, ctx)
            b = tvm.nd.array(np.zeros_like(b_np), ctx)
            func(a, b)
            tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

        check_device('llvm')
        check_device('llvm -device=arm-cpu')

    test_apply(topi.fast_exp, "fast_exp", np.exp,
               low=-88, high=88, step=0.01)

if __name__ == "__main__":
test_util()
test_ewise()
test_cast()
test_fastmath()
