[RELAY] Add softmax (apache#1841)
merrymercy authored and tqchen committed Oct 7, 2018
1 parent a7e8046 commit 0dbc8a9
Showing 6 changed files with 91 additions and 3 deletions.
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
@@ -28,6 +28,7 @@ This level enables fully connected multi-layer perceptron.
tvm.relay.sigmoid
tvm.relay.add
tvm.relay.expand_dims
tvm.relay.nn.softmax

**Level 2: Convolutions**

10 changes: 10 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -67,6 +67,16 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
  }
};

/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
  int axis;

  TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
    TVM_ATTR_FIELD(axis).set_default(1)
        .describe("The axis to sum over when computing softmax.");
  }
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
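
The attrs node above declares a single `axis` field (default 1, i.e. the channel axis of a typical `(batch, channel)` input) and registers it with TVM's reflection machinery, which is what lets the Python frontend read it back off a call node. A minimal sketch of that round trip, assuming the attribute is exposed through the usual attrs reflection (illustrative only, not part of the commit):

    import tvm
    from tvm import relay

    # Build a call to the new operator and read the attribute back.
    # relay.var is the modern spelling; the commit-era tests build
    # parameters through relay.ir_builder instead.
    x = relay.var("x", shape=(4, 10), dtype="float32")
    call = relay.nn.softmax(x, axis=-1)
    assert call.op.name == "nn.softmax"
    assert int(call.attrs.axis) == -1   # the SoftmaxAttrs field from nn.h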
20 changes: 20 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -86,3 +86,23 @@ def conv2d(data,
    return _make.conv2d(data, weight, strides, padding, dilation,
                        groups, channels, kernel_size, data_layout,
                        weight_layout, out_layout, out_dtype)


def softmax(data, axis):
    r"""Computes softmax.

    .. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}

    .. note::
        This operator can be optimized away for inference.

    Parameters
    ----------
    data: relay.Expr
        The input data to the operator.

    axis: int
        The axis to sum over when computing softmax.

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.softmax(data, axis)
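
For reference, the formula in the docstring can be checked numerically; a small NumPy sketch (not part of the commit, using the standard max-subtraction trick for stability):

    import numpy as np

    def softmax_ref(x, axis=1):
        # Shift by the row max first: softmax is invariant to adding a
        # constant along the reduced axis, and this avoids exp() overflow.
        e = np.exp(x - np.max(x, axis=axis, keepdims=True))
        return e / np.sum(e, axis=axis, keepdims=True)

    x = np.array([[1.0, 2.0, 3.0]])
    print(softmax_ref(x))        # [[0.09003057 0.24472847 0.66524096]]
    print(softmax_ref(x).sum())  # 1.0 -- each row sums to one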
6 changes: 3 additions & 3 deletions src/relay/op/nn/convolution.cc
@@ -49,9 +49,9 @@ bool Conv2DRel(const Array<Type>& types,
   CHECK_EQ(param->dilation.size(), 2);
   std::vector<IndexExpr> wshape(
       {param->channels / param->groups,
-      data->shape[1] / param->groups,
-      param->kernel_size[0],
-      param->kernel_size[1]});
+       data->shape[1] / param->groups,
+       param->kernel_size[0],
+       param->kernel_size[1]});
   wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
   wshape[kernel_layout.indexof('O')] *= param->groups;
   channels = param->channels;
43 changes: 43 additions & 0 deletions src/relay/op/nn/nn.cc
@@ -0,0 +1,43 @@
/*!
 * Copyright (c) 2018 by Contributors
 * \file nn.cc
 * \brief Property def of nn operators.
 */

#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include "../type_relations.h"

namespace tvm {
namespace relay {


TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    auto make_func = [](Expr data, int axis) {
      auto attrs = make_node<SoftmaxAttrs>();
      attrs->axis = axis;
      static const Op& op = Op::Get("nn.softmax");
      return CallNode::make(op, {data}, Attrs(attrs), {});
    };
    runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
  });

RELAY_REGISTER_OP("nn.softmax")
.describe(R"code(Softmax layer.

.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}

.. note::
    This operator can be optimized away for inference.

- **data**: The input data

)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(1)
.add_type_rel("Identity", IdentityRel);

} // namespace relay
} // namespace tvm
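
Note that the registration reuses the built-in `Identity` type relation instead of a softmax-specific one: softmax normalizes along one axis without changing shape or dtype, so the output type is simply the input type. Conceptually the relation does something like the following (a hypothetical Python paraphrase of `IdentityRel`, not the actual C++ in type_relations.h):

    def identity_rel(types):
        # types = [input_type, output_type]; the solver resolves the
        # unknown output by unifying it with the input, so
        # softmax(Tensor[(n, d), float32]) is typed Tensor[(n, d), float32].
        input_ty, _ = types
        types[1] = input_ty
        return True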
14 changes: 14 additions & 0 deletions tests/python/relay/test_op_level1.py
@@ -16,6 +16,19 @@ def test_expand_dims_infer_type():
(n, t, 1, 100), "float32")


def test_softmax():
    ib = relay.ir_builder.IRBuilder()
    n, d = tvm.var("n"), tvm.var("d")
    x = ib.param("x", relay.ty.TensorType((n, d), "float32"))
    with ib.function(x) as func:
        ib.ret(relay.nn.softmax(x, axis=1))
    ib.ret(func)

    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type()
    assert ftype.ret_type == relay.ty.TensorType((n, d), "float32")


def test_unary_op():
    for op in [relay.exp,
               relay.log,
@@ -34,3 +34,4 @@ def test_unary_op():
if __name__ == "__main__":
    test_expand_dims_infer_type()
    test_unary_op()
    test_softmax()
