diff --git a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc index 9fdfc11d8eb1..224843358526 100644 --- a/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc +++ b/src/api/operator/numpy/np_elemwise_broadcast_logic_op.cc @@ -47,25 +47,13 @@ MXNET_REGISTER_API("_npi.not_equal") void SetUFuncHelper(runtime::MXNetArgs args, runtime::MXNetRetValue* ret, const nnvm::Op* op, const nnvm::Op* op_scalar, const nnvm::Op* op_rscalar) { - int result = 0; - if (args[1].type_code() == kNDArrayHandle) { - result++; - result <<= 1; - } - if (args[0].type_code() == kNDArrayHandle) { - result++; - } - - switch (result) { - case 3 : - UFuncHelper(args, ret, op, nullptr, nullptr); - break; - case 1 : - UFuncHelper(args, ret, nullptr, op_scalar, nullptr); - break; - case 2 : - UFuncHelper(args, ret, nullptr, nullptr, op_rscalar); - break; + if (args[0].type_code() == kNDArrayHandle && + args[1].type_code() == kNDArrayHandle) { + UFuncHelper(args, ret, op, nullptr, nullptr); + } else if (args[0].type_code() == kNDArrayHandle) { + UFuncHelper(args, ret, nullptr, op_scalar, nullptr); + } else { + UFuncHelper(args, ret, nullptr, nullptr, op_rscalar); } }