Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix mixed type backward
Browse files Browse the repository at this point in the history
  • Loading branch information
yijunc committed May 6, 2020
1 parent 35916f1 commit cd28e8e
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,10 @@ MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));

MXNET_BINARY_MATH_OP(rminus, b - a);

MXNET_BINARY_MATH_OP_NC(self_grad, 1);

MXNET_BINARY_MATH_OP_NC(minus_rgrad, -1);

MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));

MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b));
Expand Down
34 changes: 32 additions & 2 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
"FCompute<cpu>",
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"});

NNVM_REGISTER_OP(_backward_npi_broadcast_add)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::self_grad,
mshadow_op::self_grad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
#ifndef _WIN32
Expand All @@ -128,7 +143,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
"FCompute<cpu>",
NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"});

NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::self_grad,
mshadow_op::minus_rgrad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
#ifndef _WIN32
Expand Down
1 change: 1 addition & 0 deletions src/operator/numpy/np_true_divide-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <vector>
#include "../../common/utils.h"
#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../numpy/np_elemwise_broadcast_op.h"

namespace mxnet {
namespace op {
Expand Down
18 changes: 17 additions & 1 deletion src/operator/numpy/np_true_divide.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,26 @@ NNVM_REGISTER_OP(_npi_true_divide)
})
#endif
.set_attr<FCompute>("FCompute<cpu>", TrueDivideBroadcastCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
.add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
.add_argument("rhs", "NDArray-or-Symbol", "Divisor array");


NNVM_REGISTER_OP(_backward_npi_broadcast_div)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::div_grad,
mshadow_op::div_rgrad>);

NNVM_REGISTER_OP(_npi_true_divide_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
Expand Down
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::self_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus_rgrad); // NOLINT()
/*!
* \brief Tuner objects, *not* automatically generated
*/
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2584,8 +2584,10 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
use_broadcast=False, equal_nan=True)

funcs = {
'add': (-1.0, 1.0, None, None),
'subtract': (-1.0, 1.0, None, None),
'add': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
lambda y, x1, x2: _np.ones(y.shape)),
'subtract': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
lambda y, x1, x2: _np.ones(y.shape) * -1),
'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
lambda y, x1, x2: _np.broadcast_to(x1, y.shape)),
'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
Expand Down

0 comments on commit cd28e8e

Please sign in to comment.