diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 45f885abf3a0..91fea5f4aeef 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -7171,8 +7171,9 @@ def greater(x1, x2, out=None): >>> np.greater(1, np.ones(1)) array([False]) """ - return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar, - _npi.less_scalar, out) + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.greater(x1, x2, out=out) + return _api_internal.greater(x1, x2, out) @set_module('mxnet.ndarray.numpy') diff --git a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc index f0ca4081b2c8..224843358526 100644 --- a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc +++ b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc @@ -44,13 +44,35 @@ MXNET_REGISTER_API("_npi.not_equal") UFuncHelper(args, ret, op, op_scalar, nullptr); }); +void SetUFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret, + const nnvm::Op* op, const nnvm::Op* op_scalar, + const nnvm::Op* op_rscalar) { + if (args[0].type_code() == kNDArrayHandle && + args[1].type_code() == kNDArrayHandle) { + UFuncHelper(args, ret, op, nullptr, nullptr); + } else if (args[0].type_code() == kNDArrayHandle) { + UFuncHelper(args, ret, nullptr, op_scalar, nullptr); + } else { + UFuncHelper(args, ret, nullptr, nullptr, op_rscalar); + } +} + +MXNET_REGISTER_API("_npi.greater") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_greater"); + const nnvm::Op* op_scalar = Op::Get("_npi_greater_scalar"); + const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar"); + SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); +}); + MXNET_REGISTER_API("_npi.less") .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; const nnvm::Op* op = Op::Get("_npi_less"); const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_less_scalar"); - UFuncHelper(args, ret, op, op_scalar, op_rscalar); + const nnvm::Op* op_rscalar = Op::Get("_npi_greater_scalar"); + SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); }); MXNET_REGISTER_API("_npi.greater_equal") @@ -58,8 +80,8 @@ MXNET_REGISTER_API("_npi.greater_equal") using namespace runtime; const nnvm::Op* op = Op::Get("_npi_greater_equal"); const nnvm::Op* op_scalar = Op::Get("_npi_greater_equal_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar"); - UFuncHelper(args, ret, op, op_scalar, op_rscalar); + const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar"); + SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); }); MXNET_REGISTER_API("_npi.less_equal") @@ -67,8 +89,8 @@ MXNET_REGISTER_API("_npi.less_equal") using namespace runtime; const nnvm::Op* op = Op::Get("_npi_less_equal"); const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar"); - const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar"); - UFuncHelper(args, ret, op, op_scalar, op_rscalar); + const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar"); + SetUFuncHelper(args, ret, op, op_scalar, op_rscalar); }); } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 6a2845e0fb24..8b50fc46f036 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1947,6 +1947,8 @@ def _add_workload_greater(array_pool): # OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32)) OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('greater', array_pool['4x1'], 2) + OpArgMngr.add_workload('greater', 2, array_pool['4x1']) # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan # OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan])) @@ -1956,6 +1958,8 @@ def _add_workload_greater_equal(array_pool): # OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32)) OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('greater_equal', array_pool['4x1'], 2) + OpArgMngr.add_workload('greater_equal', 2, array_pool['4x1']) # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan # OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan])) @@ -1965,6 +1969,8 @@ def _add_workload_less(array_pool): # OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32)) OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('less', array_pool['4x1'], 2) + OpArgMngr.add_workload('less', 2, array_pool['4x1']) # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan # OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan])) @@ -1974,6 +1980,8 @@ def _add_workload_less_equal(array_pool): # OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16)) OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32)) OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('less_equal', array_pool['4x1'], 2) + OpArgMngr.add_workload('less_equal', 2, array_pool['4x1']) # TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan # OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))