Skip to content

Commit

Permalink
fast tanh
Browse files Browse the repository at this point in the history
  • Loading branch information
hlu1 committed May 31, 2019
1 parent 1fdf111 commit fe05f22
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 10 deletions.
72 changes: 69 additions & 3 deletions topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -31,6 +31,7 @@
#include "tvm/tvm.h"
#include "tvm/ir.h"
#include "tvm/ir_pass.h"
#include "broadcast.h"

namespace topi {
using namespace tvm;
Expand All @@ -46,7 +47,6 @@ using namespace tvm;
}

TOPI_DECLARE_UNARY_OP(exp);
TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(sigmoid);
TOPI_DECLARE_UNARY_OP(sqrt);
TOPI_DECLARE_UNARY_OP(log);
Expand All @@ -56,6 +56,72 @@ TOPI_DECLARE_UNARY_OP(round);
TOPI_DECLARE_UNARY_OP(trunc);
TOPI_DECLARE_UNARY_OP(abs);


/* \brief Fast_tanh_float implementation from Eigen */
inline Tensor fast_tanh_float(const Tensor& in,
std::string name,
std::string tag) {
// Clamp the inputs to the range [-9, 9] since anything outside
// this range is +/-1.0f in single-precision.
auto x = maximum(minimum(in, make_const(in->dtype, 9.0)), make_const(in->dtype, -9.0));

// The monomial coefficients of the numerator polynomial (odd).
auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03);
auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04);
auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05);
auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08);
auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11);
auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13);
auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16);

// The monomial coefficients of the denominator polynomial (even).
auto beta_0 = make_const(in->dtype, 4.89352518554385e-03);
auto beta_2 = make_const(in->dtype, 2.26843463243900e-03);
auto beta_4 = make_const(in->dtype, 1.18534705686654e-04);
auto beta_6 = make_const(in->dtype, 1.19825839466702e-06);

return compute(x->shape,
[&](const Array<Var>& i) {
auto x2 = x(i) * x(i);
auto p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x(i) * p;

auto q = x2 * beta_6 + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;
return p / q;
},
name, tag);
}

/*!
* \brief Creates an operation that returns hyperbolic tanh of a given tensor
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is tanh
*/
inline Tensor tanh(const Tensor& x,
std::string name = "T_tanh",
std::string tag = kElementWise) {
if (x->dtype == Float(32)) {
// invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag);
} else {
// fallback to default implementation
return compute(x->shape, [&](const Array<Var>& i) {
return ::tvm::tanh(x(i));
}, name, tag);
}
}

/*!
* \brief Creates an operation that returns identity of a given tensor
*
Expand Down
14 changes: 7 additions & 7 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ def test_util():


def test_ewise():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
def test_apply(func, name, f_numpy, low, high, dtype=tvm.float32, check_round=False, skip_name_check=False):
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), dtype=dtype, name='A')

shape = (20, 3)

def test_apply(func, name, f_numpy, low, high, check_round=False, skip_name_check=False):
shape = (20, 3)
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
if not skip_name_check:
Expand Down Expand Up @@ -71,7 +70,8 @@ def check_device(device):
test_apply(topi.abs, "fabs", np.abs, -100, 100)
test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
test_apply(topi.exp, "exp", np.exp, -1, 1)
test_apply(topi.tanh, "tanh", np.tanh, -10, 10)
test_apply(topi.tanh, "tanh", np.tanh, -20, 20)
test_apply(topi.tanh, "tanh", np.tanh, -20, 20, dtype="float64")
test_apply(topi.sigmoid, "sigmoid", lambda x:1/(1+np.exp(-x)), -1, 1)
test_apply(topi.log, "log", np.log, 0, 100)
test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
Expand Down

0 comments on commit fe05f22

Please sign in to comment.