From 9ec1c4b7ccdb908be6fe39c17c7fd17ee9f2d708 Mon Sep 17 00:00:00 2001
From: Yijun Chen
Date: Sun, 10 May 2020 03:20:47 +0800
Subject: [PATCH] fix mixed type backward (#18250)

---
 src/operator/mshadow_op.h                     | 52 +++++++++++++++
 .../numpy/np_elemwise_broadcast_op.cc         | 64 +++++++++++++++++--
 .../numpy/np_elemwise_broadcast_op.cu         | 23 ++++++-
 src/operator/numpy/np_true_divide-inl.h       |  1 +
 src/operator/numpy/np_true_divide.cc          | 18 +++++-
 src/operator/numpy/np_true_divide.cu          |  4 ++
 src/operator/operator_tune.cc                 |  2 +
 tests/python/unittest/test_numpy_op.py        |  9 ++-
 8 files changed, 163 insertions(+), 10 deletions(-)

diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 4d9de29ce709..59707f89dd19 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -728,6 +728,10 @@ MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? DType(1) : -DType(1));
 
 MXNET_BINARY_MATH_OP(rminus, b - a);
 
+MXNET_BINARY_MATH_OP_NC(posone, 1);
+
+MXNET_BINARY_MATH_OP_NC(negone, -1);
+
 MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));
 
 template<>
@@ -795,6 +799,54 @@ struct mod : public mxnet_op::tunable {
   }
 };
 
+#ifndef _WIN32
+struct mixed_mod {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return mod::Map(static_cast<mshadow::half::half_t>(a), b);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return mod::Map(static_cast<float>(a), b);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return mod::Map(static_cast<double>(a), b);
+  }
+};
+
+struct mixed_rmod {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) {
+    return mod::Map(b, static_cast<mshadow::half::half_t>(a));
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return mod::Map(b, static_cast<float>(a));
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type = 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return mod::Map(b, static_cast<double>(a));
+  }
+};
+#endif
+
 struct fmod : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 0ee677adbc13..bdf25a8508dd 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -116,7 +116,22 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
     "FCompute",
     NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"});
+
+NNVM_REGISTER_OP(_backward_npi_broadcast_add)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}, {0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<cpu, mshadow_op::posone,
+                                                         mshadow_op::posone>);
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
 #ifndef _WIN32
 .set_attr<FCompute>(
     "FCompute",
     NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus,
                                 op::mshadow_op::mixed_rminus>)
 #else
 .set_attr<FCompute>(
     "FCompute",
     NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
 #endif
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"});
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"});
+
+NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}, {0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<cpu, mshadow_op::posone,
+                                                         mshadow_op::negone>);
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
 #ifndef _WIN32
 .set_attr<FCompute>(
     "FCompute",
     NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul,
                                         op::mshadow_op::mixed_mul>)
 #else
 .set_attr<FCompute>(
     "FCompute",
     NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
 #endif
@@ -159,9 +189,33 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
 .set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<cpu, mshadow_op::right,
                                                          mshadow_op::left>);
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
-.set_attr<FCompute>("FCompute", BinaryBroadcastCompute<cpu, mshadow_op::mod>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"});
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_mod)
+#ifndef _WIN32
+.set_attr<FCompute>(
+    "FCompute",
+    NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::mod, op::mshadow_op::mixed_mod,
+                                op::mshadow_op::mixed_rmod>)
+#else
+.set_attr<FCompute>(
+    "FCompute",
+    NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::mod>)
+#endif
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mod"});
+
+NNVM_REGISTER_OP(_backward_npi_broadcast_mod)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<cpu, mshadow_op::mod_grad,
+                                                         mshadow_op::mod_rgrad>);
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_power)
 #ifndef _WIN32
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index 1e0130494469..8a13b42e4846 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -40,6 +40,10 @@ NNVM_REGISTER_OP(_npi_add)
     NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>);
 #endif
 
+NNVM_REGISTER_OP(_backward_npi_broadcast_add)
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<gpu, mshadow_op::posone,
+                                                         mshadow_op::posone>);
+
 NNVM_REGISTER_OP(_npi_subtract)
 #ifndef _WIN32
 .set_attr<FCompute>(
@@ -52,6 +56,10 @@ NNVM_REGISTER_OP(_npi_subtract)
     NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
 #endif
 
+NNVM_REGISTER_OP(_backward_npi_broadcast_sub)
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<gpu, mshadow_op::posone,
+                                                         mshadow_op::negone>);
+
 NNVM_REGISTER_OP(_npi_multiply)
 #ifndef _WIN32
 .set_attr<FCompute>(
@@ -69,7 +77,20 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul)
                                                          mshadow_op::left>);
 
 NNVM_REGISTER_OP(_npi_mod)
-.set_attr<FCompute>("FCompute", BinaryBroadcastCompute<gpu, mshadow_op::mod>);
+#ifndef _WIN32
+.set_attr<FCompute>(
+    "FCompute",
+    NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::mod, op::mshadow_op::mixed_mod,
+                                op::mshadow_op::mixed_rmod>);
+#else
+.set_attr<FCompute>(
+    "FCompute",
+    NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::mod>);
+#endif
+
+NNVM_REGISTER_OP(_backward_npi_broadcast_mod)
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<gpu, mshadow_op::mod_grad,
+                                                         mshadow_op::mod_rgrad>);
 
 NNVM_REGISTER_OP(_npi_power)
 #ifndef _WIN32
diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h
index be2ce51506a1..538a026b6b8e 100644
--- a/src/operator/numpy/np_true_divide-inl.h
+++ b/src/operator/numpy/np_true_divide-inl.h
@@ -29,6 +29,7 @@
 #include <vector>
 #include "../../common/utils.h"
 #include "../tensor/elemwise_binary_broadcast_op.h"
+#include "../numpy/np_elemwise_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc
index 6edfb4dd0901..f2529b348a2c 100644
--- a/src/operator/numpy/np_true_divide.cc
+++ b/src/operator/numpy/np_true_divide.cc
@@ -81,10 +81,26 @@ NNVM_REGISTER_OP(_npi_true_divide)
   })
 #endif
 .set_attr<FCompute>("FCompute", TrueDivideBroadcastCompute<cpu>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"})
 .add_argument("lhs", "NDArray-or-Symbol", "Dividend array")
 .add_argument("rhs", "NDArray-or-Symbol", "Divisor array");
 
+
+NNVM_REGISTER_OP(_backward_npi_broadcast_div)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<cpu, mshadow_op::div_grad,
+                                                         mshadow_op::div_rgrad>);
+
 NNVM_REGISTER_OP(_npi_true_divide_scalar)
 .set_num_inputs(1)
 .set_num_outputs(1)
diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu
index 7211f4a0a006..c8eccfe140b4 100644
--- a/src/operator/numpy/np_true_divide.cu
+++ b/src/operator/numpy/np_true_divide.cu
@@ -31,6 +31,10 @@ namespace op {
 NNVM_REGISTER_OP(_npi_true_divide)
 .set_attr<FCompute>("FCompute", TrueDivideBroadcastCompute<gpu>);
 
+NNVM_REGISTER_OP(_backward_npi_broadcast_div)
+.set_attr<FCompute>("FCompute", NumpyBinaryBackwardUseIn<gpu, mshadow_op::div_grad,
+                                                         mshadow_op::div_rgrad>);
+
 NNVM_REGISTER_OP(_npi_true_divide_scalar)
 .set_attr<FCompute>("FCompute", TrueDivideScalarCompute<gpu>);
 
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index b76e341b9fc6..20bb4bb98322 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -425,6 +425,8 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::posone);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::negone);  // NOLINT()
 /*!
  * \brief Tuner objects, *not* automatically generated
  */
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index f57313c94276..88917aac0aca 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2668,10 +2668,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                             use_broadcast=False, equal_nan=True)
 
     funcs = {
-        'add': (-1.0, 1.0, None, None),
-        'subtract': (-1.0, 1.0, None, None),
+        'add': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
+                lambda y, x1, x2: _np.ones(y.shape)),
+        'subtract': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
+                     lambda y, x1, x2: _np.ones(y.shape) * -1),
         'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
                      lambda y, x1, x2: _np.broadcast_to(x1, y.shape)),
+        'mod': (1.0, 5.0, None, None),
         'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                   lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
     }
@@ -2699,7 +2702,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                     continue
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape,
                                                   lgrad, rgrad, type1, type2)
-            if func == 'subtract':
+            if func == 'subtract' or func == 'mod':
                 continue
             for type1, type2 in itertools.product(itypes, itypes):
                 if type1 == type2:
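
Note (not part of the original patch): a minimal sketch of how the fixed mixed-type backward pass can be exercised from Python, assuming MXNet is built with this change on a non-Windows platform and using the NumPy front end (mxnet.np / mxnet.autograd); the shapes, fill values, and dtypes below are illustrative only:

    from mxnet import np, npx, autograd
    npx.set_np()

    # float16 lhs and float32 rhs: the forward output is promoted to float32,
    # and with this fix the gradients flow back in each input's own dtype.
    a = np.ones((2, 3), dtype='float16')
    b = np.full((2, 3), 2.0, dtype='float32')
    a.attach_grad()
    b.attach_grad()
    with autograd.record():
        out = np.subtract(a, b)
    out.backward()
    print(a.grad)  # expected: all +1, dtype float16
    print(b.grad)  # expected: all -1, dtype float32

This mirrors the lgrad/rgrad expectations added for 'subtract' in test_numpy_op.py above; swapping in np.add, np.mod, or np.true_divide exercises the other backward registrations introduced by the patch.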