[Relay] Add fast_softmax (apache#7163)
* [Relay] Add fast_softmax

* fix

* fix
merrymercy authored Dec 25, 2020
1 parent 7dcafb0 commit 6ffd740
Showing 7 changed files with 122 additions and 5 deletions.
5 changes: 5 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -42,6 +42,11 @@
reg.register_pattern("nn.softmax", OpPattern.OPAQUE)


# fast softmax
reg.register_strategy("nn.fast_softmax", strategy.fast_softmax_strategy)
reg.register_pattern("nn.fast_softmax", OpPattern.OPAQUE)


# log_softmax
reg.register_schedule("nn.log_softmax", strategy.schedule_log_softmax)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
23 changes: 23 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -698,6 +698,29 @@ def softmax(data, axis=-1):
    return _make.softmax(data, axis)


def fast_softmax(data, axis=-1):
    r"""Computes softmax.

    Use approximation to compute exponent for faster speed.

    .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}

    .. note::
        This operator can be optimized away for inference.

    Parameters
    ----------
    data: tvm.relay.Expr
        The input data to the operator.

    axis: int, optional
        The axis to sum over when computing softmax

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.fast_softmax(data, axis)
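Since the new Relay API mirrors `softmax`, here is a minimal usage sketch (shape, dtype, and axis chosen only for illustration):

```python
import tvm
from tvm import relay

# Build a one-op Relay function around the new operator.
x = relay.var("x", shape=(1, 1000), dtype="float32")
y = relay.nn.fast_softmax(x, axis=-1)  # same signature as relay.nn.softmax
mod = tvm.IRModule.from_expr(relay.Function([x], y))
print(mod)  # the function body is a call to nn.fast_softmax
```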


def log_softmax(data, axis=-1):
    r"""Computes log softmax.
14 changes: 14 additions & 0 deletions python/tvm/relay/op/strategy/generic.py
@@ -154,6 +154,20 @@ def softmax_strategy(attrs, inputs, out_type, target):
    return strategy


@override_native_generic_func("fast_softmax_strategy")
def fast_softmax_strategy(attrs, inputs, out_type, target):
    """fast softmax generic strategy"""
    # NOTE: This op does not have an optimized manual schedule,
    # so it should only be used together with auto-scheduler.
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_softmax(topi.nn.fast_softmax),
        naive_schedule,
        name="fast_softmax.generic",
    )
    return strategy
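Because the strategy pairs the compute with `naive_schedule`, the NOTE above is the real contract: performance comes from auto-scheduler tuning. A hedged sketch of that flow (trial budget and log file name are arbitrary assumptions):

```python
import tvm
from tvm import auto_scheduler, relay

# Extract tuning tasks from a module that uses fast_softmax, then let the
# auto-scheduler search for schedules; the naive fallback is correctness-only.
x = relay.var("x", shape=(1, 1000), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.nn.fast_softmax(x)))
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], {}, target="llvm")

tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(
    auto_scheduler.TuningOptions(
        num_measure_trials=64,  # small illustrative budget
        measure_callbacks=[auto_scheduler.RecordToFile("fast_softmax_tuning.json")],
    )
)
```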


# log_softmax
@generic_func
def schedule_log_softmax(attrs, outs, target):
45 changes: 42 additions & 3 deletions python/tvm/topi/nn/softmax.py
@@ -18,12 +18,12 @@
"""TVM operator for softmax and log_softmax compute."""
from __future__ import absolute_import
import tvm
-from tvm import te
+from tvm import te, topi


@tvm.te.tag_scope(tag="softmax_output")
def softmax(x, axis=-1):
"""Perform softmax activation on the data
"""Perform softmax activation on the data.
Parameters
----------
@@ -38,6 +38,32 @@ def softmax(x, axis=-1):
    output : tvm.te.Tensor
        output shape is the same as input
    """
    return softmax_common(x, axis, False)


@tvm.te.tag_scope(tag="fast_softmax_output")
def fast_softmax(x, axis=-1):
"""Perform softmax activation on the data.
Use approximation to compute exponent for faster speed.
Parameters
----------
data : tvm.te.Tensor
can be any dimension
axis : int
channel axis
Returns
-------
output : tvm.te.Tensor
output shape is the same as input
"""
return softmax_common(x, axis, True)
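The TOPI entry point can also be used directly; a sketch with a placeholder shape (the default TE schedule here is naive and only good for checking results):

```python
import tvm
from tvm import te, topi

x = te.placeholder((1, 1000), name="x", dtype="float32")
y = topi.nn.fast_softmax(x)   # compute tagged "fast_softmax_output"
s = te.create_schedule(y.op)  # naive schedule, correctness checks only
f = tvm.build(s, [x, y], target="llvm")
```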


def softmax_common(x, axis, use_fast_exp):
    """The common part of softmax and fast_softmax"""
    shape = x.shape
    if axis < 0:
        axis = len(shape) + axis
@@ -57,6 +83,10 @@ def _compute_max(*indices):
        eval_range = insert_reduce_index(indices, k1)
        return tvm.te.max(x[eval_range], axis=k1)

    def _compute_delta(max_elem, *indices):
        non_reduce_indices = get_non_reduce_indices(indices)
        return x[indices] - max_elem[non_reduce_indices]

    def _compute_exp(max_elem, *indices):
        non_reduce_indices = get_non_reduce_indices(indices)
        return te.exp(x[indices] - max_elem[non_reduce_indices])
@@ -71,7 +101,16 @@ def _normalize(exp, expsum, *indices):

    reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
    max_elem = te.compute(reduced_shape, _compute_max, name="T_softmax_maxelem")
-    exp = te.compute(shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp")

+    if use_fast_exp:
+        delta = te.compute(
+            shape, lambda *indices: _compute_delta(max_elem, *indices), name="T_softmax_delta"
+        )
+        exp = topi.math.fast_exp(delta)
+    else:
+        exp = te.compute(
+            shape, lambda *indices: _compute_exp(max_elem, *indices), name="T_softmax_exp"
+        )
    expsum = te.compute(
        reduced_shape, lambda *indices: _compute_expsum(exp, *indices), name="T_softmax_expsum"
    )
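For reference, the staged computation above subtracts the per-row maximum before exponentiating so `exp` stays in range; a NumPy sketch of the same pipeline, where `np.exp` stands in for the polynomial `fast_exp` approximation used on the fast path:

```python
import numpy as np

def softmax_reference(x, axis=-1):
    # Mirrors the TE stages above.
    max_elem = np.max(x, axis=axis, keepdims=True)  # T_softmax_maxelem
    delta = x - max_elem                            # T_softmax_delta
    exp = np.exp(delta)                             # fast_exp(delta) on the fast path
    expsum = np.sum(exp, axis=axis, keepdims=True)  # T_softmax_expsum
    return exp / expsum                             # final normalization
```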
27 changes: 27 additions & 0 deletions src/relay/op/nn/nn.cc
@@ -299,6 +299,33 @@ RELAY_REGISTER_OP("nn.softmax")
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
.. note::
This operator can be optimized away for inference.
- **data**: The input data
)code" TVM_ADD_FILELINE)
    .set_attrs_type<SoftmaxAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_support_level(1)
    .add_type_rel("Identity", IdentityRel);

// relay.fast_softmax
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);

TVM_REGISTER_GLOBAL("relay.op.nn._make.fast_softmax").set_body_typed([](Expr data, int axis) {
  auto attrs = make_object<SoftmaxAttrs>();
  attrs->axis = axis;
  static const Op& op = Op::Get("nn.fast_softmax");
  return Call(op, {data}, Attrs(attrs), {});
});
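The packed function registered above is what the Python `_make.fast_softmax` helper resolves to; calling it through the global registry is equivalent (a sketch, shape arbitrary):

```python
import tvm
from tvm import relay

# relay.nn.fast_softmax(data, axis) bottoms out in this packed function.
make_fn = tvm.get_global_func("relay.op.nn._make.fast_softmax")
x = relay.var("x", shape=(4, 8), dtype="float32")
call = make_fn(x, -1)  # a relay.Call node to the nn.fast_softmax op
```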

RELAY_REGISTER_OP("nn.fast_softmax")
.describe(R"code(Softmax layer.
Use approximation to compute exponent for faster speed.
.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
.. note::
This operator can be optimized away for inference.
3 changes: 2 additions & 1 deletion src/relay/op/tensor/reduce.cc
@@ -573,7 +573,8 @@ Array<te::Tensor> VarianceCompute(const Attrs& attrs, const Array<te::Tensor>& i
    count -= 1;
  }
  std::vector<Integer> expand_shape;
-  auto sq_diff = topi::power(topi::subtract(data, mean), 2);
+  auto diff = topi::subtract(data, mean);
+  auto sq_diff = topi::multiply(diff, diff);
  if (param->exclude) {
    axes = GetExcludeAxes(sq_diff->shape.size(), param->axis);
    ICHECK_NE(axes.size(), 0);
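The variance change squares the difference with a plain multiply instead of `topi::power`, which would lower to a generic per-element `pow()` call. The same contrast at the Python TOPI level (a sketch; shapes arbitrary):

```python
from tvm import te, topi

a = te.placeholder((16,), name="a", dtype="float32")
b = te.placeholder((16,), name="b", dtype="float32")
diff = topi.subtract(a, b)
sq_pow = topi.power(diff, 2.0)      # lowers to an elementwise pow()
sq_mul = topi.multiply(diff, diff)  # lowers to a single multiply
```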
10 changes: 9 additions & 1 deletion tests/python/topi/python/test_topi_math.py
Expand Up @@ -199,7 +199,7 @@ def verify(from_dtype, to_dtype, low=-100, high=100):

def test_fastmath():
    def test_apply(func, name, f_numpy, low, high, step, dtype="float32"):
-        a_np = np.arange(low, high, step).astype(dtype)
+        a_np = np.arange(low, high, step).astype(dtype).reshape((1, -1))
        b_np = f_numpy(a_np)
        A = te.placeholder(a_np.shape, dtype=dtype, name="A")
        B = func(A)
@@ -224,6 +224,14 @@ def check_device(device):
    test_apply(topi.fast_exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
    test_apply(topi.fast_erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
    test_apply(topi.fast_tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)
    test_apply(
        topi.nn.fast_softmax,
        "fast_softmax",
        tvm.topi.testing.softmax_python,
        low=-10,
        high=10,
        step=0.01,
    )
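The reshape to `(1, -1)` above makes the shared 1-D test data two-dimensional, matching the reference in `tvm.topi.testing`. A sketch of what such a 2-D softmax reference computes (the real implementation lives in `tvm.topi.testing.softmax_python`):

```python
import numpy as np

def softmax_python_sketch(a_np):
    # 2-D row-wise softmax, in the spirit of tvm.topi.testing.softmax_python.
    assert len(a_np.shape) == 2
    max_elem = np.amax(a_np, axis=1, keepdims=True)
    e = np.exp(a_np - max_elem)
    return e / np.sum(e, axis=1, keepdims=True)
```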


if __name__ == "__main__":
