Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] Fix less/greater bug with scalar input (#18642)
Browse files Browse the repository at this point in the history
* fix ffi

* fix less/greater error

* back

* submodule

* fixed

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
Yiyan66 and Ubuntu authored Jul 4, 2020
1 parent d1b0a09 commit 6462887
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
5 changes: 3 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -7171,8 +7171,9 @@ def greater(x1, x2, out=None):
>>> np.greater(1, np.ones(1))
array([False])
"""
return _ufunc_helper(x1, x2, _npi.greater, _np.greater, _npi.greater_scalar,
_npi.less_scalar, out)
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.greater(x1, x2, out=out)
return _api_internal.greater(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
Expand Down
34 changes: 28 additions & 6 deletions src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,31 +44,53 @@ MXNET_REGISTER_API("_npi.not_equal")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

// Dispatch a binary ufunc FFI call to the right registered op depending on
// which of the two arguments are NDArrays:
//   * both NDArrays            -> tensor-tensor op
//   * only the first (lhs)     -> scalar-on-rhs op
//   * only the second (rhs)    -> reflected scalar op (scalar on lhs)
void SetUFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret,
                    const nnvm::Op* op, const nnvm::Op* op_scalar,
                    const nnvm::Op* op_rscalar) {
  const bool lhs_is_ndarray = args[0].type_code() == kNDArrayHandle;
  const bool rhs_is_ndarray = args[1].type_code() == kNDArrayHandle;
  if (lhs_is_ndarray && rhs_is_ndarray) {
    UFuncHelper(args, ret, op, nullptr, nullptr);
  } else if (lhs_is_ndarray) {
    UFuncHelper(args, ret, nullptr, op_scalar, nullptr);
  } else {
    // Scalar on the left-hand side: use the reflected-scalar operator.
    UFuncHelper(args, ret, nullptr, nullptr, op_rscalar);
  }
}

MXNET_REGISTER_API("_npi.greater")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  using namespace runtime;
  // Reflected scalar case (scalar > ndarray) is equivalent to ndarray < scalar,
  // hence the less_scalar op as the rscalar variant.
  SetUFuncHelper(args, ret,
                 Op::Get("_npi_greater"),
                 Op::Get("_npi_greater_scalar"),
                 Op::Get("_npi_less_scalar"));
});

MXNET_REGISTER_API("_npi.less")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_less");
  const nnvm::Op* op_scalar = Op::Get("_npi_less_scalar");
  // Reflected scalar case (scalar < ndarray) is equivalent to ndarray > scalar,
  // so the rscalar variant must be greater_scalar (the old less_scalar was the bug).
  const nnvm::Op* op_rscalar = Op::Get("_npi_greater_scalar");
  SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.greater_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_greater_equal");
  const nnvm::Op* op_scalar = Op::Get("_npi_greater_equal_scalar");
  // Reflected scalar case (scalar >= ndarray) is equivalent to ndarray <= scalar,
  // so the rscalar variant must be less_equal_scalar (not greater_equal_scalar).
  const nnvm::Op* op_rscalar = Op::Get("_npi_less_equal_scalar");
  SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

MXNET_REGISTER_API("_npi.less_equal")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
  using namespace runtime;
  const nnvm::Op* op = Op::Get("_npi_less_equal");
  const nnvm::Op* op_scalar = Op::Get("_npi_less_equal_scalar");
  // Reflected scalar case (scalar <= ndarray) is equivalent to ndarray >= scalar,
  // so the rscalar variant must be greater_equal_scalar (not less_equal_scalar).
  const nnvm::Op* op_rscalar = Op::Get("_npi_greater_equal_scalar");
  SetUFuncHelper(args, ret, op, op_scalar, op_rscalar);
});

} // namespace mxnet
8 changes: 8 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1947,6 +1947,8 @@ def _add_workload_greater(array_pool):
# OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('greater', array_pool['4x1'], 2)
OpArgMngr.add_workload('greater', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))

Expand All @@ -1956,6 +1958,8 @@ def _add_workload_greater_equal(array_pool):
# OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('greater_equal', array_pool['4x1'], 2)
OpArgMngr.add_workload('greater_equal', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))

Expand All @@ -1965,6 +1969,8 @@ def _add_workload_less(array_pool):
# OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('less', array_pool['4x1'], 2)
OpArgMngr.add_workload('less', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))

Expand All @@ -1974,6 +1980,8 @@ def _add_workload_less_equal(array_pool):
# OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float32), np.array([-2, 5, 1, 4, 3], dtype=np.float32))
OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('less_equal', array_pool['4x1'], 2)
OpArgMngr.add_workload('less_equal', 2, array_pool['4x1'])
# TODO(junwu): mxnet currently does not have a consistent behavior as NumPy in dealing with np.nan
# OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))

Expand Down

0 comments on commit 6462887

Please sign in to comment.