fix mixed type backward (apache#18250)
yijunc authored and sxjscience committed Jul 1, 2020
1 parent 5ba7a77 commit 9ec1c4b
Showing 8 changed files with 163 additions and 10 deletions.
52 changes: 52 additions & 0 deletions src/operator/mshadow_op.h
@@ -728,6 +728,10 @@ MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));

MXNET_BINARY_MATH_OP(rminus, b - a);

MXNET_BINARY_MATH_OP_NC(posone, 1);

MXNET_BINARY_MATH_OP_NC(negone, -1);

MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));

template<>
@@ -795,6 +799,54 @@ struct mod : public mxnet_op::tunable {
}
};

#ifndef _WIN32
struct mixed_mod {
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return mod::Map(static_cast<mshadow::half::half_t>(a), b);
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return mod::Map(static_cast<float>(a), b);
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, float>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return mod::Map(static_cast<double>(a), b);
}
};

struct mixed_rmod {
template<typename DType,
typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
return mod::Map(b, static_cast<mshadow::half::half_t>(a));
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static float Map(DType a, float b) {
return mod::Map(b, static_cast<float>(a));
}

template<typename DType,
typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
std::is_same<DType, float>::value ||
std::is_integral<DType>::value, int>::type = 0>
MSHADOW_XINLINE static double Map(DType a, double b) {
return mod::Map(b, static_cast<double>(a));
}
};
#endif

struct fmod : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
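The new posone and negone functors return the constants +1 and -1 and act as the elementwise partial derivatives for add and subtract, while mixed_mod and mixed_rmod promote the narrower operand (integral or half) to the wider floating type before delegating to mod::Map. A minimal front-end sketch of the behaviour these functors back, assuming an MXNet build with the numpy interface enabled (array values are illustrative only):

from mxnet import autograd, np, npx
npx.set_np()

a = np.array([[1.0, 2.0], [3.0, 4.0]], dtype='float16')  # narrower operand
b = np.array([[0.5, 0.5], [0.5, 0.5]], dtype='float32')  # wider operand
a.attach_grad()
b.attach_grad()
with autograd.record():
    y = a - b        # mixed-precision subtract, computed in float32
y.backward()         # implicit head gradient of ones
print(a.grad)        # expected: all +1 in a's dtype (posone)
print(b.grad)        # expected: all -1 in b's dtype (negone)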
64 changes: 59 additions & 5 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -116,7 +116,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
"FCompute<cpu>",
NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"});

NNVM_REGISTER_OP(_backward_npi_broadcast_add)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::posone,
mshadow_op::posone>);

MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
#ifndef _WIN32
@@ -129,7 +144,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
"FCompute<cpu>",
NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"});

NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::posone,
mshadow_op::negone>);

MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
#ifndef _WIN32
@@ -159,9 +189,33 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::right,
mshadow_op::left>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"});
MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_mod)
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<cpu>",
NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::mod, op::mshadow_op::mixed_mod,
op::mshadow_op::mixed_rmod>)
#else
.set_attr<FCompute>(
"FCompute<cpu>",
NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::mod>)
#endif
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mod"});

NNVM_REGISTER_OP(_backward_npi_broadcast_mod)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::mod_grad,
mshadow_op::mod_rgrad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_power)
#ifndef _WIN32
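Each new _backward_npi_broadcast_* op takes three inputs (the output gradient plus both forward inputs) and produces two outputs (the lhs and rhs gradients); NumpyBinaryBackwardUseIn applies the functor pair elementwise and then sums over the broadcast axes, using the requested temp space, so each gradient matches its operand's shape. The functor pairs encode the partial derivatives: add -> (1, 1), subtract -> (1, -1), multiply -> (b, a), mod -> (1, -floor(a/b)), divide -> (1/b, -a/b^2). Below is a rough NumPy reference for the mod case; reduce_to is a hypothetical helper, not MXNet code, and the formulas are the assumed semantics of mod_grad/mod_rgrad:

import numpy as np

def reduce_to(grad, shape):
    # Sum over the axes that were broadcast so `grad` collapses back to `shape`.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for ax, size in enumerate(shape):
        if size == 1 and grad.shape[ax] != 1:
            grad = grad.sum(axis=ax, keepdims=True)
    return grad

def broadcast_mod_backward(ograd, a, b):
    # d(a % b)/da = 1             (assumed mod_grad)
    # d(a % b)/db = -floor(a / b) (assumed mod_rgrad)
    ga = reduce_to(ograd, a.shape)
    gb = reduce_to(ograd * -np.floor(a / b), b.shape)
    return ga, gb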
23 changes: 22 additions & 1 deletion src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -40,6 +40,10 @@ NNVM_REGISTER_OP(_npi_add)
NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>);
#endif

NNVM_REGISTER_OP(_backward_npi_broadcast_add)
.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::posone,
mshadow_op::posone>);

NNVM_REGISTER_OP(_npi_subtract)
#ifndef _WIN32
.set_attr<FCompute>(
@@ -52,6 +56,10 @@ NNVM_REGISTER_OP(_npi_subtract)
NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
#endif

NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::posone,
mshadow_op::negone>);

NNVM_REGISTER_OP(_npi_multiply)
#ifndef _WIN32
.set_attr<FCompute>(
@@ -69,7 +77,20 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
mshadow_op::left>);

NNVM_REGISTER_OP(_npi_mod)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
#ifndef _WIN32
.set_attr<FCompute>(
"FCompute<gpu>",
NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::mod, op::mshadow_op::mixed_mod,
op::mshadow_op::mixed_rmod>);
#else
.set_attr<FCompute>(
"FCompute<gpu>",
NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::mod>);
#endif

NNVM_REGISTER_OP(_backward_npi_broadcast_mod)
.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::mod_grad,
mshadow_op::mod_rgrad>);

NNVM_REGISTER_OP(_npi_power)
#ifndef _WIN32
1 change: 1 addition & 0 deletions src/operator/numpy/np_true_divide-inl.h
@@ -29,6 +29,7 @@
#include <vector>
#include "../../common/utils.h"
#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../numpy/np_elemwise_broadcast_op.h"

namespace mxnet {
namespace op {
18 changes: 17 additions & 1 deletion src/operator/numpy/np_true_divide.cc
@@ -81,10 +81,26 @@ NNVM_REGISTER_OP(_npi_true_divide)
})
#endif
.set_attr<FCompute>("FCompute<cpu>", TrueDivideBroadcastCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"})
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
.add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
.add_argument("rhs", "NDArray-or-Symbol", "Divisor array");


NNVM_REGISTER_OP(_backward_npi_broadcast_div)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyBinaryBackwardUseIn<cpu, mshadow_op::div_grad,
mshadow_op::div_rgrad>);

NNVM_REGISTER_OP(_npi_true_divide_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
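For true division the functor pair is div_grad (1/b) and div_rgrad (-a/b^2), i.e. the partial derivatives of a/b with respect to each operand. A quick finite-difference sanity check of those two formulas in plain Python (illustrative only):

a, b, eps = 3.0, 2.0, 1e-6
num_da = ((a + eps) / b - a / b) / eps    # numeric d(a/b)/da
num_db = (a / (b + eps) - a / b) / eps    # numeric d(a/b)/db
assert abs(num_da - 1.0 / b) < 1e-4       # div_grad  = 1/b
assert abs(num_db + a / b ** 2) < 1e-4    # div_rgrad = -a/b**2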
4 changes: 4 additions & 0 deletions src/operator/numpy/np_true_divide.cu
@@ -31,6 +31,10 @@ namespace op {
NNVM_REGISTER_OP(_npi_true_divide)
.set_attr<FCompute>("FCompute<gpu>", TrueDivideBroadcastCompute<gpu>);

NNVM_REGISTER_OP(_backward_npi_broadcast_div)
.set_attr<FCompute>("FCompute<gpu>", NumpyBinaryBackwardUseIn<gpu, mshadow_op::div_grad,
mshadow_op::div_rgrad>);

NNVM_REGISTER_OP(_npi_true_divide_scalar)
.set_attr<FCompute>("FCompute<gpu>", TrueDivideScalarCompute<gpu, mshadow_op::true_divide>);

2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -425,6 +425,8 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::posone); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::negone); // NOLINT()
/*!
* \brief Tuner objects, *not* automatically generated
*/
9 changes: 6 additions & 3 deletions tests/python/unittest/test_numpy_op.py
@@ -2668,10 +2668,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
use_broadcast=False, equal_nan=True)

funcs = {
'add': (-1.0, 1.0, None, None),
'subtract': (-1.0, 1.0, None, None),
'add': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
lambda y, x1, x2: _np.ones(y.shape)),
'subtract': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
lambda y, x1, x2: _np.ones(y.shape) * -1),
'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
lambda y, x1, x2: _np.broadcast_to(x1, y.shape)),
'mod': (1.0, 5.0, None, None),
'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
}
@@ -2699,7 +2702,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)

if func == 'subtract':
if func == 'subtract' or func == 'mod':
continue
for type1, type2 in itertools.product(itypes, itypes):
if type1 == type2:
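The test changes give add and subtract analytic gradients (ones for both operands, with -1 for the rhs of subtract), add mod to the mixed-precision table as a forward-only case (None gradients), and skip the integer-integer dtype combinations for mod just as was already done for subtract. A minimal sketch of the new mixed-type mod path as seen from the front end, assuming a numpy-enabled MXNet build; the expected-gradient comments follow the mod_grad/mod_rgrad formulas assumed above:

from mxnet import autograd, np, npx
npx.set_np()

x1 = np.random.uniform(1.0, 5.0, (2, 3)).astype('float16')
x2 = np.random.uniform(1.0, 5.0, (2, 3)).astype('float32')
x1.attach_grad()
x2.attach_grad()
with autograd.record():
    y = np.mod(x1, x2)   # mixed float16/float32 mod, registered above
y.backward()
print(x1.grad)           # expected: ones (mod_grad)
print(x2.grad)           # expected: -floor(x1/x2), in float32 (mod_rgrad)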
