From 4c41afd1e6014c5cde00f4d253474ffa1e141cac Mon Sep 17 00:00:00 2001
From: Przemyslaw Tredak
Date: Tue, 26 Nov 2019 09:03:15 -0800
Subject: [PATCH] Backport of #16827, #16791 and #16888 to 1.6 branch (#16901)

* refactor and reduce float types for some functions, also add bitwise_xor (#16827)

* Mixed precision binary op backward (use in) for numpy (#16791)

* mixed precision binary op backward

* reduce unix cpu runtime

* Add evaluation_loss to the estimator base class. (#16888)

* Add evaluation_loss to the estimator base class.

* Update the base estimator class to support the separate evaluation loss.

* Add evaluation loss to the base estimator class.

* Add unittest for evaluation loss in the test_evaluation function

* Update estimator.py

* Update estimator.py
---
 .../gluon/contrib/estimator/estimator.py      |  11 +-
 python/mxnet/ndarray/numpy/_op.py             |  40 ++-
 python/mxnet/numpy/multiarray.py              |  42 ++-
 python/mxnet/numpy_dispatch_protocol.py       |   1 +
 python/mxnet/symbol/numpy/_symbol.py          |  35 +-
 src/operator/elemwise_op_common.h             |   3 +-
 .../numpy/np_elemwise_broadcast_op.cc         | 243 +-------------
 .../numpy/np_elemwise_broadcast_op.cu         |  75 +----
 src/operator/numpy/np_elemwise_broadcast_op.h | 114 ++++++-
 .../np_elemwise_broadcast_op_extended.cc      | 305 ++++++++++++++++++
 .../np_elemwise_broadcast_op_extended.cu      | 108 +++++++
 src/operator/operator_tune.cc                 |   4 +-
 .../tensor/elemwise_binary_broadcast_op.h     | 136 +++++---
 src/operator/tensor/elemwise_binary_op.h      | 148 +++++----
 .../tensor/elemwise_binary_scalar_op.h        |  20 ++
 src/operator/tensor/elemwise_unary_op.h       |   4 +-
 tests/python/unittest/test_gluon_estimator.py |   4 +-
 .../unittest/test_numpy_interoperability.py   |  13 +
 tests/python/unittest/test_numpy_op.py        |  23 +-
 19 files changed, 892 insertions(+), 437 deletions(-)
 create mode 100644 src/operator/numpy/np_elemwise_broadcast_op_extended.cc
 create mode 100644 src/operator/numpy/np_elemwise_broadcast_op_extended.cu

diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py
index 83b954d02e10..54a0b165016e 100644
--- a/python/mxnet/gluon/contrib/estimator/estimator.py
+++ b/python/mxnet/gluon/contrib/estimator/estimator.py
@@ -59,6 +59,9 @@ class Estimator(object):
         Trainer to apply optimizer on network parameters.
     context : Context or list of Context
         Device(s) to run the training on.
+    evaluation_loss : gluon.loss.Loss
+        Loss (objective) function to calculate during evaluation.
+        If set to None, the same loss function as self.loss is used for evaluation.
     """

@@ -85,12 +88,16 @@ def __init__(self, net,
                  metrics=None,
                  initializer=None,
                  trainer=None,
-                 context=None):
+                 context=None,
+                 evaluation_loss=None):
         self.net = net
         self.loss = self._check_loss(loss)
         self._train_metrics = _check_metrics(metrics)
         self._add_default_training_metrics()
         self._add_validation_metrics()
+        self.evaluation_loss = self.loss
+        if evaluation_loss is not None:
+            self.evaluation_loss = self._check_loss(evaluation_loss)

         self.logger = logging.Logger(name='Estimator', level=logging.INFO)
         self.logger.addHandler(logging.StreamHandler(sys.stdout))
@@ -228,7 +235,7 @@ def evaluate_batch(self,
         """
         data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
         pred = [self.net(x) for x in data]
-        loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
+        loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
         # update metrics
         for metric in val_metrics:
             if isinstance(metric, metric_loss):
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index ff404a7a2df7..ed3d9d8e0695 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -36,7 +36,7 @@
     'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
     'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
     'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
-    'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
+    'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
     'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
     'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff']

@@ -4289,6 +4289,44 @@ def hypot(x1, x2, out=None, **kwargs):
     return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


+@set_module('mxnet.ndarray.numpy')
+@wrap_np_binary_func
+def bitwise_xor(x1, x2, out=None, **kwargs):
+    r"""
+    Compute the bit-wise XOR of two arrays element-wise.
+
+    Parameters
+    ----------
+    x1, x2 : ndarray or scalar
+        Only integer and boolean types are handled. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which becomes the shape of the output).
+    out : ndarray, optional
+        A location into which the result is stored. If provided, it must have a shape that the
+        inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray
+        Result.
+ + Examples + -------- + >>> np.bitwise_xor(13, 17) + 28 + + >>> np.bitwise_xor(31, 5) + 26 + >>> np.bitwise_xor(np.array([31,3], dtype='int32'), 5) + array([26, 6]) + + >>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32')) + array([26, 5]) + >>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool')) + array([ True, False]) + """ + return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def ldexp(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index c623f67967ba..ad5fb5444ee3 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -53,8 +53,8 @@ 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', - 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', - 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', + 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', + 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff'] @@ -6194,6 +6194,44 @@ def hypot(x1, x2, out=None, **kwargs): return _mx_nd_np.hypot(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def bitwise_xor(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise XOR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : ndarray or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the + inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + out : ndarray + Result. 
+ + Examples + -------- + >>> np.bitwise_xor(13, 17) + 28 + + >>> np.bitwise_xor(31, 5) + 26 + >>> np.bitwise_xor(np.array([31,3], dtype=np.int32), 5) + array([26, 6]) + + >>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32')) + array([26, 5]) + >>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool')) + array([ True, False]) + """ + return _mx_nd_np.bitwise_xor(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def ldexp(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index cdd21af829de..8a4a90cb4a7a 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -222,6 +222,7 @@ def _register_array_function(): 'ceil', 'trunc', 'floor', + 'bitwise_xor', 'logical_not', 'equal', 'not_equal', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index d3837d2bd1dd..e4ac4628c98a 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -38,7 +38,7 @@ 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', - 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', + 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff'] @@ -4057,17 +4057,16 @@ def hypot(x1, x2, out=None, **kwargs): Parameters ---------- - x1, x2 : array_like + x1, x2 : _Symbol or scalar Leg of the triangle(s). - out : ndarray, None, or tuple of ndarray and None, optional + out : _Symbol or None, optional A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or `None`, - a freshly-allocated array is returned. A tuple (possible only as a - keyword argument) must have length equal to the number of outputs. + a freshly-allocated array is returned. Returns ------- - z : ndarray + z : _Symbol or scalar The hypotenuse of the triangle(s). This is a scalar if both `x1` and `x2` are scalars. @@ -4079,6 +4078,30 @@ def hypot(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def bitwise_xor(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise XOR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : _Symbol or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : _Symbol or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + out : _Symbol or scalar + Result. 
+ """ + return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out) + + @set_module('mxnet.symbol.numpy') def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 6711297718b2..2cdd73a95801 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -209,7 +209,8 @@ inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs, CHECK(in_attrs->at(0) == mshadow::kInt64 || in_attrs->at(0) == mshadow::kInt32 || in_attrs->at(0) == mshadow::kInt8 || - in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types."; + in_attrs->at(0) == mshadow::kUint8 || + in_attrs->at(0) == mshadow::kBool) << "Only supports integer types."; if (n_in != -1) { CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; } diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index a76e59d30dc6..f2adfc125d02 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -28,16 +28,6 @@ namespace mxnet { namespace op { -bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return in_attrs->at(0) != -1; -} - #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ @@ -147,22 +137,9 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply) "FCompute", NumpyBinaryBroadcastComputeWithBool) #endif -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); - -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}); - -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}); - -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign) -.describe(R"code()code" ADD_FILELINE) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"}); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"}); -NNVM_REGISTER_OP(_backward_npi_copysign) +NNVM_REGISTER_OP(_backward_npi_broadcast_mul) .set_num_inputs(3) .set_num_outputs(2) .set_attr("TIsBackward", true) @@ -174,44 +151,16 @@ NNVM_REGISTER_OP(_backward_npi_copysign) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +.set_attr("FCompute", NumpyBinaryBackwardUseIn); -NNVM_REGISTER_OP(_npi_lcm) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", -[](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; -}) -.set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", ElemwiseIntType<2, 1>) -.set_attr("FInplaceOption", -[](const NodeAttrs& attrs){ - return std::vector >{{0, 0}, {1, 0}}; -}) -.set_attr("FGradient", MakeZeroGradNodes) -.set_attr("FCompute", BinaryBroadcastCompute) -.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") -.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); - 
-NNVM_REGISTER_OP(_npi_lcm_scalar) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) -.set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr("FInferType", ElemwiseIntType<1, 1>) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 0}}; - }) -.set_attr("FGradient", MakeZeroGradNodes) -.add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") -.set_attr("FCompute", BinaryScalarOp::Compute); +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}); +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); @@ -244,177 +193,5 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -inline bool IsFloatType(const int dtype) { - return (dtype == mshadow::kFloat16 || - dtype == mshadow::kFloat32 || - dtype == mshadow::kFloat64); -} - -inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); - // check if it is float16, float32 or float64. If not, raise error. 
- CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n"; - return out_attrs->at(0) != -1; -} - -NNVM_REGISTER_OP(_npi_arctan2) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"x1", "x2"}; - }) -.set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", Arctan2OpType) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2"}) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector >{{0, 0}}; - }) -.add_argument("x1", "NDArray-or-Symbol", "The input array") -.add_argument("x2", "NDArray-or-Symbol", "The input array"); - -NNVM_REGISTER_OP(_backward_npi_arctan2) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"}); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -bool HypotOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); - - CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n"; - return out_attrs->at(0) != -1; -} - -// rigister hypot that do not support int here -NNVM_REGISTER_OP(_npi_hypot) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"x1", "x2"}; - }) -.set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", HypotOpType) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_hypot"}) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector >{{0, 0}, {1, 0}}; - }) -.add_argument("x1", "NDArray-or-Symbol", "The input array") -.add_argument("x2", "NDArray-or-Symbol", "The input array"); - -NNVM_REGISTER_OP(_backward_npi_hypot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector > {{0, 1}}; - }) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastCompute) 
-.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"}); - -NNVM_REGISTER_OP(_backward_npi_ldexp) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 1}}; - }) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", BinaryScalarOp::Backward); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", BinaryScalarOp::Backward); - } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index a0a277df211f..59dfc25db963 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -64,35 +64,16 @@ NNVM_REGISTER_OP(_npi_multiply) NumpyBinaryBroadcastComputeWithBool); #endif +NNVM_REGISTER_OP(_backward_npi_broadcast_mul) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); + NNVM_REGISTER_OP(_npi_mod) .set_attr("FCompute", BinaryBroadcastCompute); NNVM_REGISTER_OP(_npi_power) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_npi_copysign) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_npi_lcm) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_backward_npi_copysign) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -NNVM_REGISTER_OP(_npi_arctan2) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_backward_npi_arctan2) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); -NNVM_REGISTER_OP(_npi_hypot) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_backward_npi_hypot) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - NNVM_REGISTER_OP(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); @@ -117,53 +98,5 @@ NNVM_REGISTER_OP(_npi_power_scalar) NNVM_REGISTER_OP(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_npi_copysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_rcopysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_copysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -NNVM_REGISTER_OP(_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_rarctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) -.set_attr("FCompute", 
BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_lcm_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -NNVM_REGISTER_OP(_backward_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); - -NNVM_REGISTER_OP(_backward_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); - } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index 1a4596fba91c..a2b7877dc444 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -25,6 +25,7 @@ #ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ #define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ +#include #include #include @@ -34,6 +35,16 @@ namespace mxnet { namespace op { +inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + return in_attrs->at(0) != -1; +} + inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) { LOG(FATAL) << "Operator " << op_name << " does not support combination of " << common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2) @@ -381,11 +392,13 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, } template -void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, +void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 2U); @@ -396,7 +409,104 @@ void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, return; } - PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); + const TBlob& ograd = inputs[0]; + const TBlob& lgrad = outputs[0]; + const TBlob& rgrad = outputs[1]; + + if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + // If any of the inputs is a float, it's the same type as the output + // So 2 of the 3 tensors have the same data type + Stream *s = ctx.get_stream(); + mxnet::TShape new_lshape, new_rshape, new_oshape; + using namespace broadcast; + const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_, + &new_lshape, &new_rshape, &new_oshape) != 0; + + // Prepare all the temporary memory + size_t workspace_size_l = 0, workspace_size_r = 0; + TBlob temp_tblob; // The TBlob for casted input data + TBlob temp_igrad; // The TBlob for casted grad results + size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? 
lgrad.Size() : rgrad.Size(); + Tensor workspace; + + MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, { + BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, { + workspace_size_l = ReduceWorkspaceSize( + s, new_lshape, req[0], new_oshape, new_lshape, new_rshape); + workspace_size_r = ReduceWorkspaceSize( + s, new_rshape, req[1], new_oshape, new_lshape, new_rshape); + }); + size_t workspace_size = std::max(workspace_size_l, workspace_size_r); + size_t cast_tensor_size = tensor_size * sizeof(OType); + // Allocate the temporary memories now + Tensor temp_space = + ctx.requested[0].get_space_typed( + Shape1(workspace_size + cast_tensor_size * 2), s); + // Tensor for temp_tblob + Tensor temp_tblob_tensor( + reinterpret_cast(temp_space.dptr_), + Shape1(tensor_size), s); + // Tensor for temp_igrad + Tensor temp_igrad_tensor( + reinterpret_cast(temp_space.dptr_) + tensor_size, + Shape1(tensor_size), s); + temp_tblob = + TBlob(temp_tblob_tensor) + .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_)); + temp_igrad = + TBlob(temp_igrad_tensor) + .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_)); + if (temp_igrad.Size() != 0) { + Kernel::Launch(s, temp_igrad.Size(), temp_igrad.dptr()); + } + workspace = + Tensor(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s); + }); + // Cast the input that does not have consistent type to temp_tblob + CastCompute( + attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob}); + if (!need_bc) { + if (lhs.type_flag_ != ograd.type_flag_) { + ElemwiseBinaryOp::BackwardUseIn( + attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad}); + } else { + ElemwiseBinaryOp::BackwardUseIn( + attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad}); + } + } else { + if (lhs.type_flag_ != ograd.type_flag_) { + MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, { + BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, { + BinaryBroadcastBackwardUseInImplWithWorkspace( + ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad}, + workspace, new_lshape, new_rshape, new_oshape); + }); + }); + } else { + MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, { + BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, { + BinaryBroadcastBackwardUseInImplWithWorkspace( + ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad}, + workspace, new_lshape, new_rshape, new_oshape); + }); + }); + } + } + + // If both inputs are floating numbers, cast the igrad to the input that has + // the different data type + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + if (lhs.type_flag_ != ograd.type_flag_) { + CastCompute(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad}); + } else { + CastCompute(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad}); + } + } + } else { + // Case where both inputs are integer types, should not even do + // backward computation for this case. + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); + } } } // namespace op diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc new file mode 100644 index 000000000000..84c47e597883 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_elemwise_broadcast_op_extended.cc
+ * \brief CPU Implementation of extended functions for elementwise numpy binary broadcast operator.
+ */
+
+#include "../../common/utils.h"
+#include "./np_elemwise_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
+ NNVM_REGISTER_OP(name) \
+ .set_num_inputs(1) \
+ .set_num_outputs(1) \
+ .set_attr_parser([](NodeAttrs* attrs) { \
+ attrs->parsed = std::stod(attrs->dict["scalar"]); \
+ }) \
+ .set_attr("FInferShape", ElemwiseShape<1, 1>) \
+ .set_attr("FInferType", NumpyBinaryScalarType) \
+ .set_attr("FInplaceOption", \
+ [](const NodeAttrs& attrs){ \
+ return std::vector >{{0, 0}}; \
+ }) \
+ .add_argument("data", "NDArray-or-Symbol", "source input") \
+ .add_argument("scalar", "float", "scalar input")
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr("FCompute", BinaryBroadcastCompute)
+.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"});
+
+NNVM_REGISTER_OP(_backward_npi_copysign)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr("TIsBackward", true)
+.set_attr("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector >{{0, 1}};
+ })
+.set_attr("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector{ResourceRequest::kTempSpace};
+ })
+.set_attr("FCompute", BinaryBroadcastBackwardUseIn);
+
+NNVM_REGISTER_OP(_npi_lcm)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr("FListInputNames",
+[](const NodeAttrs& attrs) {
+ return std::vector{"lhs", "rhs"};
+})
+.set_attr("FInferShape", BinaryBroadcastShape)
+.set_attr("FInferType", ElemwiseIntType<2, 1>)
+.set_attr("FInplaceOption",
+[](const NodeAttrs& attrs){
+ return std::vector >{{0, 0}, {1, 0}};
+})
+.set_attr("FGradient", MakeZeroGradNodes)
+.set_attr("FCompute", BinaryBroadcastIntCompute)
+.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
+.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+
+NNVM_REGISTER_OP(_npi_lcm_scalar)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser([](NodeAttrs* attrs) {
+ attrs->parsed = std::stod(attrs->dict["scalar"]);
+ })
+.set_attr("FInferShape", ElemwiseShape<1, 1>)
+.set_attr("FInferType", ElemwiseIntType<1, 1>)
+.set_attr("FInplaceOption",
+ [](const NodeAttrs& attrs){
+ return std::vector >{{0, 0}};
+ })
+.set_attr("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "source input")
+.add_argument("scalar", "int", "scalar input")
+.set_attr("FCompute", BinaryScalarOp::Compute);
+
+NNVM_REGISTER_OP(_npi_bitwise_xor)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr("FListInputNames",
+[](const NodeAttrs& attrs) {
+ return std::vector{"lhs", "rhs"};
+})
+.set_attr("FInferShape", BinaryBroadcastShape)
+.set_attr("FInferType", ElemwiseIntType<2, 1>)
+.set_attr("FInplaceOption",
+[](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}}; +}) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FCompute", BinaryBroadcastIntCompute) +.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") +.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); + +NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + // check if it is float16, float32 or float64. If not, raise error. 
+ CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as input.\n";
+ return out_attrs->at(0) != -1;
+}
+
+NNVM_REGISTER_OP(_npi_arctan2)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector{"x1", "x2"};
+ })
+.set_attr("FInferShape", BinaryBroadcastShape)
+.set_attr("FInferType", Arctan2OpType)
+.set_attr("FCompute", BinaryBroadcastCompute)
+.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2"})
+.set_attr("FInplaceOption",
+ [](const NodeAttrs& attrs) {
+ return std::vector >{{0, 0}};
+ })
+.add_argument("x1", "NDArray-or-Symbol", "The input array")
+.add_argument("x2", "NDArray-or-Symbol", "The input array");
+
+NNVM_REGISTER_OP(_backward_npi_arctan2)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr("TIsBackward", true)
+.set_attr("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector{ResourceRequest::kTempSpace};
+ })
+.set_attr("FCompute", BinaryBroadcastBackwardUseIn);
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar)
+.set_attr("FCompute", BinaryScalarOp::Compute)
+.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar)
+.set_attr("FCompute", BinaryScalarOp::Compute)
+.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"});
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar)
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.set_attr("FCompute",
+ BinaryScalarOp::Backward);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.set_attr("FCompute",
+ BinaryScalarOp::Backward);
+
+bool HypotOpType(const nnvm::NodeAttrs& attrs,
+ std::vector* in_attrs,
+ std::vector* out_attrs) {
+ CHECK_EQ(in_attrs->size(), 2U);
+ CHECK_EQ(out_attrs->size(), 1U);
+
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+ TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
+ TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+ TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
+
+ CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as input.\n";
+ return out_attrs->at(0) != -1;
+}
+
+// register hypot here; it does not support int inputs
+NNVM_REGISTER_OP(_npi_hypot)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr("FListInputNames",
+ [](const NodeAttrs& attrs) {
+ return std::vector{"x1", "x2"};
+ })
+.set_attr("FInferShape", BinaryBroadcastShape)
+.set_attr("FInferType", HypotOpType)
+.set_attr("FCompute", BinaryBroadcastCompute)
+.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_hypot"})
+.set_attr("FInplaceOption",
+ [](const NodeAttrs& attrs) {
+ return std::vector >{{0, 0}, {1, 0}};
+ })
+.add_argument("x1", "NDArray-or-Symbol", "The input array")
+.add_argument("x2", "NDArray-or-Symbol", "The input array");
+
+NNVM_REGISTER_OP(_backward_npi_hypot)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr("TIsBackward", true)
+.set_attr("FInplaceOption",
+ [](const NodeAttrs& attrs) {
+ return std::vector > {{0, 1}};
+ })
+.set_attr("FResourceRequest",
+ [](const NodeAttrs& attrs) {
+ return std::vector{ResourceRequest::kTempSpace};
+ })
+.set_attr("FCompute", BinaryBroadcastBackwardUseIn);
+
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp)
+.set_attr("FCompute", BinaryBroadcastCompute)
+.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"}); + +NNVM_REGISTER_OP(_backward_npi_ldexp) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu new file mode 100644 index 000000000000..f858fb4a4e79 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_broadcast_op_extended.cu + * \brief GPU Implementation of extended functions for elementwise binary broadcast operator. 
+ */ + +#include "./np_elemwise_broadcast_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_copysign) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_npi_lcm) +.set_attr("FCompute", BinaryBroadcastIntCompute); + +NNVM_REGISTER_OP(_npi_bitwise_xor) +.set_attr("FCompute", BinaryBroadcastIntCompute); + +NNVM_REGISTER_OP(_backward_npi_copysign) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_npi_arctan2) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_arctan2) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_hypot) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_hypot) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_copysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_rcopysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_copysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_lcm_scalar) +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + +NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + +NNVM_REGISTER_OP(_npi_ldexp) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_npi_ldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_rldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_ldexp) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_backward_npi_ldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_backward_npi_rldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 633f63026bc0..e2a4c8af3099 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -396,10 +396,10 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() 
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT() diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index b48ed389ba98..ffd0f123070a 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -332,6 +332,37 @@ struct csr_dns_map_kernel { } // namespace mxnet_op +template +void BinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (outputs[0].shape_.Size() == 0U) return; + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + ElemwiseBinaryOp::ComputeInt(attrs, ctx, inputs, req, outputs); + } else { + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; + } + MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); + }); + }); + } +} + template void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -345,22 +376,21 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, if (!ndim) { ElemwiseBinaryOp::Compute(attrs, ctx, inputs, req, outputs); } else { - if (req[0] != kNullOp) { - mshadow::Stream *s = ctx.get_stream(); - if (outputs[0].type_flag_ == mshadow::kBool) { - LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; - } - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); - }); - }); + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; } + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); + }); + }); } } @@ -377,19 +407,18 @@ void BinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, if (!ndim) { ElemwiseBinaryOp::ComputeWithBool(attrs, ctx, inputs, req, outputs); } else { - if (req[0] != kNullOp) { - mshadow::Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - 
mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); - }); + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); }); - } + }); } } @@ -406,20 +435,19 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs, if (!ndim) { ElemwiseBinaryOp::ComputeLogic(attrs, ctx, inputs, req, outputs); } else { - if (req[0] != kNullOp) { - mshadow::Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), - outputs[0].dptr()); - }); - }); - } + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + outputs[0].dptr()); + }); + }); } } @@ -671,6 +699,32 @@ BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs); +template +void BinaryBroadcastBackwardUseInImplWithWorkspace(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const mshadow::Tensor& workspace, + const mxnet::TShape& new_lshape, + const mxnet::TShape& new_rshape, + const mxnet::TShape& new_oshape) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace broadcast; + Stream *s = ctx.get_stream(); + const TBlob lgrad = outputs[0].reshape(new_lshape); + const TBlob rgrad = outputs[1].reshape(new_rshape); + const TBlob ograd = inputs[0].reshape(new_oshape); + const TBlob lhs = inputs[1].reshape(new_lshape); + const TBlob rhs = inputs[2].reshape(new_rshape); + if (ograd.Size() != 0) { + Reduce(s, lgrad, req[0], workspace, + ograd, lhs, rhs); + Reduce(s, rgrad, req[1], workspace, + ograd, lhs, rhs); + } +} + template inline void BinaryBroadcastBackwardUseInImpl(const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index c046a28f16b2..bc5140a5d75f 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -474,6 +474,30 @@ class ElemwiseBinaryOp : public OpBase { std::vector *in_attrs, std::vector *out_attrs); + template + static void ComputeInt(const 
nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } + }); + }); + } + template static void Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -481,25 +505,24 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - if (outputs[0].type_flag_ == mshadow::kBool) { - LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; - } - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr()); - } - }); - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; } + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } + }); + }); } template @@ -509,22 +532,21 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr()); - } - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } }); - } + }); } template @@ -534,23 +556,22 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { 
- MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), - inputs[1].dptr()); - } - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), + inputs[1].dptr()); + } }); - } + }); } template @@ -560,22 +581,21 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr()); - } - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } }); - } + }); } template diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index 834bbdbfc3d1..3e8702813a7c 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -244,6 +244,26 @@ class BinaryScalarOp : public UnaryOp { }); } + template + static void ComputeInt(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + DCHECK_EQ(inputs.size(), 1); + DCHECK_EQ(outputs.size(), 1); + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + const double alpha = nnvm::get(attrs.parsed); + MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + }); + }); + } + template static void ComputeLogic(const nnvm::NodeAttrs &attrs, const OpContext &ctx, diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 188ccd68a340..8886e15e3972 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -453,8 +453,8 @@ void CastCompute(const nnvm::NodeAttrs& attrs, Tensor out = outputs[0].FlatTo1D(s); MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, SrcDType, { Tensor data = inputs[0].FlatTo1D(s); - if (outputs[0].type_flag_ != inputs[0].type_flag_ || - req[0] != kWriteInplace) { + if ((outputs[0].type_flag_ != inputs[0].type_flag_ || + req[0] 
diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py
index aaf9839b29f3..cf913a6161c0 100644
--- a/tests/python/unittest/test_gluon_estimator.py
+++ b/tests/python/unittest/test_gluon_estimator.py
@@ -83,13 +83,15 @@ def test_validation():
     ctx = mx.cpu()
     loss = gluon.loss.L2Loss()
     acc = mx.metric.Accuracy()
+    evaluation_loss = gluon.loss.L1Loss()
     net.initialize(ctx=ctx)
     trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
     est = Estimator(net=net,
                     loss=loss,
                     metrics=acc,
                     trainer=trainer,
-                    context=ctx)
+                    context=ctx,
+                    evaluation_loss=evaluation_loss)
     # Input dataloader
     est.fit(train_data=dataloader,
             val_data=dataloader,
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 8416b1a9099f..5b6cea70c28d 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -788,6 +788,18 @@ def _add_workload_lcm():
     OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32))
 
 
+def _add_workload_bitwise_xor():
+    OpArgMngr.add_workload('bitwise_xor', np.array([False, False, True, True], dtype=np.bool),
+                           np.array([False, True, False, True], dtype=np.bool))
+    for dtype in [np.int8, np.int32, np.int64]:
+        zeros = np.array([0], dtype=dtype)
+        ones = np.array([-1], dtype=dtype)
+        OpArgMngr.add_workload('bitwise_xor', zeros, zeros)
+        OpArgMngr.add_workload('bitwise_xor', ones, zeros)
+        OpArgMngr.add_workload('bitwise_xor', zeros, ones)
+        OpArgMngr.add_workload('bitwise_xor', ones, ones)
+
+
 def _add_workload_ldexp():
     OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, np.int8))
     OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(3, np.int8))
@@ -1194,6 +1206,7 @@ def _prepare_workloads():
     _add_workload_inner()
     _add_workload_hypot()
     _add_workload_lcm()
+    _add_workload_bitwise_xor()
     _add_workload_ldexp()
     _add_workload_subtract(array_pool)
     _add_workload_multiply(array_pool)
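
For context, the test_validation change above exercises the new
`evaluation_loss` argument end to end. A minimal usage sketch, assuming
the 1.6 contrib Estimator API with this patch applied:

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon.contrib.estimator import Estimator

    net = gluon.nn.Dense(1)
    net.initialize(ctx=mx.cpu())
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
    est = Estimator(net=net,
                    loss=gluon.loss.L2Loss(),             # drives training
                    trainer=trainer,
                    context=mx.cpu(),
                    evaluation_loss=gluon.loss.L1Loss())  # drives evaluate_batch
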
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 643b9c1e9ba0..3aef5ca6c3b2 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1649,6 +1649,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
         'power': (1.0, 2.0, [lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2],
                   [lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]),
         'lcm': (-100, 100, [None], None, [[_np.int32]]),
+        'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]),
         'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)],
                     [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]),
         'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],
@@ -1683,7 +1684,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
 @with_seed()
 @use_np
 def test_np_mixed_precision_binary_funcs():
-    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
+    itypes = [np.bool, np.int8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
+    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, ltype, rtype):
         class TestMixedBinary(HybridBlock):
             def __init__(self, func):
                 super(TestMixedBinary, self).__init__()
@@ -1717,13 +1720,15 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                             use_broadcast=False, equal_nan=True)
 
     funcs = {
-        'add': (-1.0, 1.0),
-        'subtract': (-1.0, 1.0),
-        'multiply': (-1.0, 1.0),
+        'add': (-1.0, 1.0, None, None),
+        'subtract': (-1.0, 1.0, None, None),
+        'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
+                     lambda y, x1, x2: _np.broadcast_to(x1, y.shape))
     }
 
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),
+                   ((3, 0), (3, 0)),
                    ((3, 1), (3, 0)),
                    ((0, 2), (1, 2)),
                    ((2, 3, 4), (3, 1)),
@@ -1733,16 +1738,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
     itypes = [np.bool, np.int8, np.int32, np.int64]
     ftypes = [np.float16, np.float32, np.float64]
     for func, func_data in funcs.items():
-        low, high = func_data
+        low, high, lgrad, rgrad = func_data
         for lshape, rshape in shape_pairs:
             for type1, type2 in itertools.product(itypes, ftypes):
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)
 
             for type1, type2 in itertools.product(ftypes, ftypes):
                 if type1 == type2:
                     continue
-                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
 
 
 @with_seed()
@@ -4102,7 +4107,7 @@ def hybrid_forward(self, F, a):
                 mx_out.backward()
                 if (np_out.size == 0):
                     np_backward = _np.zeros(shape)
-                else: 
+                else:
                     np_backward = np_diff_backward(_np.ones(np_out.shape, dtype=itype), n=n, axis=axis)
                 assert x.grad.shape == np_backward.shape
                 assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol)
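
For context, a sketch (plain NumPy, not part of the test suite) of the
gradient convention the new lgrad/rgrad entries encode for 'multiply':
for y = x1 * x2 under broadcasting, the reference gradient with respect
to each input is the other input broadcast to y's shape. For 'add' and
'subtract' the entries are None, so presumably no reference gradient is
compared for them.

    import numpy as _np

    x1 = _np.random.uniform(-1.0, 1.0, (3, 1)).astype(_np.float32)
    x2 = _np.random.randint(-100, 100, (3, 2)).astype(_np.int32)
    y = x1 * x2
    lgrad = _np.broadcast_to(x2, y.shape)  # reference dy/dx1, as in the lambda above
    rgrad = _np.broadcast_to(x1, y.shape)  # reference dy/dx2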