Skip to content

Commit

Permalink
[NNVM][TOPI] Add gradients for broadcast_* ops (apache#1234)
Browse files Browse the repository at this point in the history
  • Loading branch information
nhynes authored and tqchen committed Jun 27, 2018
1 parent 3f7cce3 commit 44d8203
Show file tree
Hide file tree
Showing 9 changed files with 318 additions and 98 deletions.
9 changes: 8 additions & 1 deletion nnvm/include/nnvm/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,19 @@ class Tuple {
inline Tuple(std::initializer_list<ValueType> init) {
this->assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init the vector
*/
inline Tuple(std::vector<ValueType> init) { // NOLINT(runtime/explicit)
this->assign(init.begin(), init.end());
}
/*!
* \brief move constructor from Tuple
* \param src the source shape
*/

inline Tuple(Tuple<ValueType>&& src) { // NOLINT(*)
inline Tuple(Tuple<ValueType>&& src) { // NOLINT(runtime/explicit)
this->swap(src);
}
/*!
Expand Down
4 changes: 4 additions & 0 deletions nnvm/python/nnvm/top/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ def _compute(attrs, inputs, out_info):
# min
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)

# collapse sum
reg.register_pattern("collapse_sum", OpPattern.COMM_REDUCE)
reg.register_schedule("collapse_sum", _fschedule_reduce)
61 changes: 57 additions & 4 deletions nnvm/src/top/tensor/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,15 @@ Example::
broadcast_add(x, y) = [[ 1., 1., 1.],
[ 2., 2., 2.]]
)code" NNVM_ADD_FILELINE);
)code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    // Addition passes the output gradient through unchanged; each side
    // only needs it summed back down to that input's own shape.
    const auto& base_name = n->attrs.name;
    NodeEntry grad_lhs = MakeNode("collapse_sum", base_name + "_dlhs",
                                  { ograds[0], n->inputs[0] });
    NodeEntry grad_rhs = MakeNode("collapse_sum", base_name + "_drhs",
                                  { ograds[0], n->inputs[1] });
    return std::vector<NodeEntry>{ grad_lhs, grad_rhs };
});


NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_sub, subtract)
Expand All @@ -256,7 +264,18 @@ Example::
broadcast_sub(x, y) = [[ 1., 1., 1.],
[ 0., 0., 0.]]
)code" NNVM_ADD_FILELINE);
)code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    const auto& base_name = n->attrs.name;
    // lhs receives the output gradient as-is, reduced to lhs's shape.
    NodeEntry grad_lhs = MakeNode("collapse_sum", base_name + "_dlhs",
                                  { ograds[0], n->inputs[0] });
    // rhs receives the negated output gradient, reduced to rhs's shape.
    NodeEntry neg_ograd = MakeNode("negative", base_name + "_drhs_neg",
                                   { ograds[0] });
    NodeEntry grad_rhs = MakeNode("collapse_sum", base_name + "_drhs",
                                  { neg_ograd, n->inputs[1] });
    return std::vector<NodeEntry>{ grad_lhs, grad_rhs };
});


NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mul, multiply)
Expand All @@ -273,7 +292,22 @@ Example::
broadcast_mul(x, y) = [[ 0., 0., 0.],
[ 1., 1., 1.]]
)code" NNVM_ADD_FILELINE);
)code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    const auto& base_name = n->attrs.name;
    // d(lhs * rhs)/d(lhs) = rhs: scale the output gradient by rhs, then
    // sum the broadcast result back down to lhs's shape.
    NodeEntry rhs_times_ograd = MakeNode("broadcast_mul", base_name + "_dlhs_mul",
                                         { n->inputs[1], ograds[0] });
    NodeEntry grad_lhs = MakeNode("collapse_sum", base_name + "_dlhs_sum",
                                  { rhs_times_ograd, n->inputs[0] });
    // Symmetric case for rhs: scale by lhs, reduce to rhs's shape.
    NodeEntry lhs_times_ograd = MakeNode("broadcast_mul", base_name + "_drhs_mul",
                                         { n->inputs[0], ograds[0] });
    NodeEntry grad_rhs = MakeNode("collapse_sum", base_name + "_drhs_sum",
                                  { lhs_times_ograd, n->inputs[1] });
    return std::vector<NodeEntry>{ grad_lhs, grad_rhs };
});


NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_div, divide)
Expand All @@ -291,7 +325,26 @@ Example::
broadcast_div(x, y) = [[ 3., 3., 3.],
[ 2., 2., 2.]]
)code" NNVM_ADD_FILELINE);
)code" NNVM_ADD_FILELINE)
.set_attr<FGradient>(
  "FGradient", [](const NodePtr& n,
                  const std::vector<NodeEntry>& ograds) {
    // Forward op: out = lhs / rhs.
    // d(out)/d(lhs) = 1 / rhs, so the lhs gradient is ograd / rhs summed
    // back down to lhs's shape.
    NodeEntry dlhs = MakeNode("collapse_sum", n->attrs.name + "_dlhs_sum", {
      MakeNode("broadcast_div", n->attrs.name + "_dlhs_div",
               { ograds[0], n->inputs[1] }),
      n->inputs[0]
    });
    // d(out)/d(rhs) = -lhs / rhs^2 = -(lhs / rhs) / rhs = -out / rhs.
    // The forward output (NodeEntry{n, 0, 0}) is reused for lhs / rhs.
    // The previous code divided by (2 * rhs) and dropped the minus sign,
    // which is not the derivative of a quotient.
    NodeEntry dy = MakeNode("negative", n->attrs.name + "_drhs_neg", {
      MakeNode("broadcast_div", n->attrs.name + "_drhs_div",
               { NodeEntry{n, 0, 0}, n->inputs[1] })
    });
    // Chain rule, then reduce the broadcast result to rhs's shape.
    NodeEntry drhs = MakeNode("collapse_sum", n->attrs.name + "_drhs_sum", {
      MakeNode("broadcast_mul", n->attrs.name + "_drhs_mul", { dy, ograds[0] }),
      n->inputs[1]
    });
    return std::vector<NodeEntry>{ dlhs, drhs };
});

} // namespace top
} // namespace nnvm
144 changes: 86 additions & 58 deletions nnvm/src/top/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/compiler/util.h>
#include <nnvm/top/tensor.h>
#include <numeric>
#include "../op_common.h"
#include "../elemwise_op_common.h"
#include "topi/detail/constant_utils.h"
#include "topi/elemwise.h"
#include "topi/reduction.h"
#include "topi/transform.h"

namespace nnvm {
namespace top {
Expand All @@ -21,58 +25,61 @@ using namespace nnvm::compiler;
// reduce
DMLC_REGISTER_PARAMETER(ReduceParam);

inline TShape ReduceShapeImpl(const TShape& ishape,
const TShape& axis,
bool keepdims,
bool exclude) {
/*!
 * \brief Resolve the list of axes a reduction operates on.
 * \param indim number of dimensions of the input tensor
 * \param axis axes supplied by the caller; negative entries are offset
 *  by indim
 * \param exclude when true, return the axes NOT listed in axis
 * \return the resolved axes in ascending order. An empty axis yields every
 *  axis 0..indim-1, regardless of exclude.
 */
inline TShape GetReduceAxes(const uint32_t indim,
                            const TShape& axis,
                            bool exclude) {
  if (axis.ndim() == 0) {
    TShape r_axes(indim);
    std::iota(r_axes.begin(), r_axes.end(), 0);
    return r_axes;
  }

  CHECK_LT(axis[axis.ndim() - 1], indim)
    << "Reduction axis " << axis[axis.ndim() - 1]
    << " exceeds input dimensions " << indim;

  // Normalise negative axes and validate every entry.
  TShape in_axis = axis;
  for (auto& i : in_axis) {
    i = i < 0 ? i + indim : i;
    CHECK_GE(i, 0) << "axis out of bounds in reduce operator";
    CHECK_LT(i, indim) << "axis out of bounds in reduce operator";
  }
  std::sort(in_axis.begin(), in_axis.end());
  if (!exclude) return in_axis;

  // Complement of in_axis within [0, indim).
  // NOTE(review): the result size assumes axis holds no duplicates - confirm
  // callers guarantee that.
  TShape r_axis(indim - in_axis.ndim());
  for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
    // j must be bounds-checked: once every listed axis has been consumed,
    // in_axis[j] would otherwise be read past the end of in_axis.
    if (j < in_axis.ndim() && i == in_axis[j]) {
      ++j;
      continue;
    }
    r_axis[k++] = i;
  }
  return r_axis;
}

inline TShape ReduceShapeImpl(const TShape& ishape,
const TShape& axis,
bool keepdims,
bool exclude) {
uint32_t indim = ishape.ndim();
TShape r_axes = GetReduceAxes(indim, axis, exclude);
if (!r_axes.ndim()) return ishape;
if (r_axes.ndim() == indim)
return TShape(keepdims ? indim : 1);

if (keepdims) {
TShape oshape(ishape);
if (exclude) {
for (dim_t i = 0, j = 0; i < ishape.ndim(); ++i) {
if (j < in_axis.ndim() && i == in_axis[j]) {
++j;
continue;
}
oshape[i] = 1;
}
return oshape;
}

for (dim_t i = 0; i < in_axis.ndim(); ++i) {
oshape[in_axis[i]] = 1;
for (unsigned i = 0, j = 0; i < indim; ++i) {
if (i != r_axes[j]) continue;
oshape[i] = 1;
++j;
}
return oshape;
}

if (exclude) {
TShape oshape = TShape(in_axis.ndim());
for (dim_t i = 0; i < in_axis.ndim(); ++i) {
oshape[i] = ishape[in_axis[i]];
}
return oshape;
}
TShape oshape = TShape(std::max<dim_t>(1, ishape.ndim() - in_axis.ndim()));
for (dim_t i = 0, j = 0, k = 0; i < ishape.ndim(); ++i) {
if (j < in_axis.ndim() && i == in_axis[j]) {
TShape oshape(indim - r_axes.ndim());
for (unsigned i = 0, j = 0, k = 0; i < indim; ++i) {
if (i == r_axes[j]) {
++j;
continue;
}
Expand All @@ -95,6 +102,16 @@ inline bool ReduceShape(const nnvm::NodeAttrs& attrs,
return true;
}

/*!
 * \brief Shape inference function registered for collapse_sum.
 * \param attrs attributes of the node, forwarded to the shape-assign macro
 * \param in_attrs input shapes; exactly two are required
 * \param out_attrs output shapes; exactly one is required
 * \return false when the first input reports ndim() == 1, true otherwise
 */
inline bool CollapseShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
// NOTE(review): this bails out when the first input is 1-D. If the intent
// was to wait for a not-yet-known shape, the test would presumably be
// ndim() == 0 - confirm against the other shape functions in this file.
if ((*in_attrs)[0].ndim() == 1) return false;
// The output shape is taken from the second (reference) input.
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, (*in_attrs)[1]);
return true;
}

template<typename PType>
inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
PType param;
Expand All @@ -103,18 +120,21 @@ inline void AxesParamParser(nnvm::NodeAttrs* attrs) {
attrs->parsed = std::move(param);
}

/*!
 * \brief Begin registering operator `op` with the ReduceParam plumbing:
 *  the ReduceParam fields as arguments, AxesParamParser<ReduceParam> as the
 *  attribute parser, ParamGetAttrDict<ReduceParam> as FGetAttrDict, and a
 *  single output. The expansion ends without a semicolon so further
 *  registration calls can be chained after it.
 */
#define NNVM_REGISTER_BASE_REDUCE_OP(op) \
NNVM_REGISTER_OP(op) \
.add_arguments(ReduceParam::__FIELDS__()) \
.set_attr_parser(AxesParamParser<ReduceParam>) \
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_num_outputs(1)

#define NNVM_REGISTER_REDUCE_OP(op) \
NNVM_REGISTER_OP(op) \
NNVM_REGISTER_BASE_REDUCE_OP(op) \
.add_argument("data", "Tensor", "The input") \
.add_arguments(ReduceParam::__FIELDS__()) \
.set_attr_parser(AxesParamParser<ReduceParam>) \
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<ReduceParam>) \
.set_attr<FInferShape>("FInferShape", ReduceShape) \
.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) \
.set_attr<FCorrectLayout>("FCorrectLayout", \
ElemwiseFixedLayoutUnknownOut<1, 1>) \
.set_num_inputs(1) \
.set_num_outputs(1)
.set_num_inputs(1)

NNVM_REGISTER_REDUCE_OP(sum)
.describe(R"code(Computes the sum of array elements over given axes.
Expand All @@ -139,20 +159,10 @@ Example::
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
Array<Expr> axis;
if (param.exclude) {
std::set<dim_t> exclude_axis;
for (dim_t i = 0; i < param.axis.ndim(); ++i) {
exclude_axis.insert(param.axis[i]);
}
for (dim_t i = 0; i < static_cast<int>(inputs[0].ndim()); ++i) {
if (exclude_axis.count(i) == 0) {
axis.push_back(make_const(Int(32), i));
}
}
} else {
axis = ShapeToArray(param.axis);
}
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::sum(inputs[0], axis, param.keepdims) };
})
Expand All @@ -178,7 +188,9 @@ NNVM_REGISTER_REDUCE_OP(max)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::max(inputs[0], axis, param.keepdims) };
})
Expand Down Expand Up @@ -210,7 +222,9 @@ NNVM_REGISTER_REDUCE_OP(min)
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
auto axis = ShapeToArray(param.axis);
TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
param.axis, param.exclude);
auto axis = ShapeToArray(r_axes);
return Array<Tensor>{
topi::min(inputs[0], axis, param.keepdims) };
})
Expand All @@ -233,6 +247,20 @@ NNVM_REGISTER_REDUCE_OP(min)
};
});

// collapse_sum: two-input operator ("data" and the reference "as") whose
// output shape is inferred by CollapseShape. It is registered through
// NNVM_REGISTER_BASE_REDUCE_OP, so it also carries the ReduceParam
// arguments and parser, although the compute function below never reads
// attrs.
NNVM_REGISTER_BASE_REDUCE_OP(collapse_sum)
.add_argument("data", "Tensor", "The input")
.add_argument("as", "Tensor", "The reference")
.set_attr<FInferShape>("FInferShape", CollapseShape)
.set_attr<FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseFixedLayoutUnknownOut<2, 1>)
.set_num_inputs(2)
.describe(R"code(Reduces lhs to the shape of rhs via sum)code" NNVM_ADD_FILELINE)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const NodeAttrs& attrs,
const Array<Tensor>& inputs,
const Array<Tensor>& out_info) {
// Only the shape of the second input is used; its values are not passed on.
return Array<Tensor>{ topi::collapse_sum(inputs[0], inputs[1]->shape) };
});

} // namespace top
} // namespace nnvm
4 changes: 1 addition & 3 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
const ExpandDimsParam& param = nnvm::get<ExpandDimsParam>(n->attrs.parsed);
return std::vector<NodeEntry> {
MakeNode("sum", n->attrs.name + "_grad", {ograds[0]},
{{"axis", std::to_string(param.axis)}})
MakeNode("collapse_sum", n->attrs.name + "_grad", {ograds[0], n->inputs[0]})
};
})
.set_support_level(1);
Expand Down
Loading

0 comments on commit 44d8203

Please sign in to comment.