diff --git a/.codecov.yml b/.codecov.yml index 97624c21b8fa..70037e6483ff 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -4,6 +4,9 @@ codecov: require_ci_to_pass: yes coverage: + status: + project: off + patch: off precision: 2 round: down range: "70...100" diff --git a/.gitignore b/.gitignore index c50d1ec99b9f..9fafdb13cb7e 100644 --- a/.gitignore +++ b/.gitignore @@ -121,6 +121,10 @@ cmake_install.cmake # Mac OS X .DS_Store +# Windows +windows_package.7z +windows_package + #Notebook Automated Test !tests/nightly/test_tutorial_config.txt !tests/nightly/TestNotebook diff --git a/CMakeLists.txt b/CMakeLists.txt index 8955551dfeb2..0fa8c9c51af5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -169,7 +169,7 @@ if(MSVC) add_definitions(-DDMLC_STRICT_CXX11) add_definitions(-DNOMINMAX) set(CMAKE_C_FLAGS "/MP") - set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} /bigobj") + set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} ${CMAKE_CXX_FLAGS} /bigobj") else() include(CheckCXXCompilerFlag) if(USE_CXX14_IF_AVAILABLE) diff --git a/ci/build_windows.py b/ci/build_windows.py index 2590d211c671..c8d3af515b5a 100755 --- a/ci/build_windows.py +++ b/ci/build_windows.py @@ -31,15 +31,18 @@ import tempfile import time import zipfile +import requests from distutils.dir_util import copy_tree from enum import Enum -from subprocess import check_call +from subprocess import check_call, call from util import * KNOWN_VCVARS = { + # https://gitlab.kitware.com/cmake/cmake/issues/18920 'VS 2015': r'C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\x86_amd64\vcvarsx86_amd64.bat', - 'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsx86_amd64.bat' + 'VS 2017': r'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvarsx86_amd64.bat', + 'VS 2019': r'C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat', } @@ -54,6 +57,8 @@ class BuildFlavour(Enum): CMAKE_FLAGS = { 'WIN_CPU': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -67,6 +72,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_CPU_MKLDNN': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -80,6 +87,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_CPU_MKLDNN_MKL': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -93,6 +102,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_CPU_MKL': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=OFF ' '-DUSE_CUDNN=OFF ' '-DENABLE_CUDA_RTC=OFF ' @@ -106,6 +117,8 @@ class BuildFlavour(Enum): '-DCMAKE_BUILD_TYPE=Release') , 'WIN_GPU': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=ON ' '-DUSE_CUDNN=ON ' '-DENABLE_CUDA_RTC=ON ' @@ -115,11 +128,12 @@ class BuildFlavour(Enum): '-DUSE_LAPACK=ON ' '-DUSE_DIST_KVSTORE=OFF ' '-DMXNET_CUDA_ARCH="5.2" ' - '-DCMAKE_CXX_FLAGS="/FS /MD /O2 /Ob2" ' '-DUSE_MKL_IF_AVAILABLE=OFF ' '-DCMAKE_BUILD_TYPE=Release') , 'WIN_GPU_MKLDNN': ( + '-DCMAKE_C_COMPILER=cl ' + '-DCMAKE_CXX_COMPILER=cl ' '-DUSE_CUDA=ON ' '-DUSE_CUDNN=ON ' '-DENABLE_CUDA_RTC=ON ' @@ -130,7 +144,6 @@ class BuildFlavour(Enum): '-DUSE_DIST_KVSTORE=OFF ' '-DMXNET_CUDA_ARCH="5.2" ' '-DUSE_MKLDNN=ON ' - '-DCMAKE_CXX_FLAGS="/FS /MD /O2 /Ob2" ' '-DCMAKE_BUILD_TYPE=Release') } @@ -140,39 +153,65 @@ def windows_build(args): 
logging.info("Using vcvars environment:\n{}".format(args.vcvars)) path = args.output - os.makedirs(path, exist_ok=True) mxnet_root = get_mxnet_root() logging.info("Found MXNet root: {}".format(mxnet_root)) - url = 'https://github.com/Kitware/CMake/releases/download/v3.16.1/cmake-3.16.1-win64-x64.zip' - with tempfile.TemporaryDirectory() as tmpdir: - cmake_file_path = download_file(url, tmpdir) - with zipfile.ZipFile(cmake_file_path, 'r') as zip_ref: - # Create $tmpdir\cmake-3.16.1-win64-x64\bin\cmake.exe - zip_ref.extractall(tmpdir) + if 'GPU' in args.flavour: + # Get Thrust version to be shipped in CUDA 11, due to flakiness of + # older Thrust versions with MSVC 19 compiler + with remember_cwd(): + tmpdirname = tempfile.mkdtemp() + os.chdir(tmpdirname) + r = requests.get('https://github.com/thrust/thrust/archive/1.9.8.zip', allow_redirects=True) + with open('thrust.zip', 'wb') as f: + f.write(r.content) + with zipfile.ZipFile('thrust.zip', 'r') as zip_ref: + zip_ref.extractall('.') + thrust_path = os.path.join(tmpdirname, "thrust-1.9.8") + + + # CUDA Thrust / CUB + VS 2019 is flaky: try multiple times if it fails + MAXIMUM_TRY = 5 + build_try = 0 + + while build_try < MAXIMUM_TRY: + if os.path.exists(path): + shutil.rmtree(path) + os.makedirs(path, exist_ok=True) with remember_cwd(): os.chdir(path) - cmd = "\"{}\" && {} -G \"NMake Makefiles JOM\" {} {}".format( - args.vcvars, - os.path.join(tmpdir, 'cmake-3.16.1-win64-x64', 'bin', 'cmake.exe'), - CMAKE_FLAGS[args.flavour], mxnet_root) + env = os.environ.copy() + if 'GPU' in args.flavour: + env["CXXFLAGS"] = '/FS /MD /O2 /Ob2 /I {}'.format(thrust_path) + env["CUDAFLAGS"] = '-I {}'.format(thrust_path) + cmd = "\"{}\" && cmake -GNinja {} {}".format(args.vcvars, + CMAKE_FLAGS[args.flavour], + mxnet_root) logging.info("Generating project with CMake:\n{}".format(cmd)) - check_call(cmd, shell=True) + check_call(cmd, shell=True, env=env) - cmd = "\"{}\" && jom".format(args.vcvars) - logging.info("Building with jom:\n{}".format(cmd)) + cmd = "\"{}\" && ninja".format(args.vcvars) + logging.info("Building:\n{}".format(cmd)) t0 = int(time.time()) - check_call(cmd, shell=True) + ret = call(cmd, shell=True) + - logging.info( - "Build flavour: {} complete in directory: \"{}\"".format( - args.flavour, os.path.abspath(path))) - logging.info("Build took {}".format( - datetime.timedelta(seconds=int(time.time() - t0)))) - windows_package(args) + if ret != 0: + build_try += 1 + logging.info("{} build(s) have failed".format(build_try)) + else: + logging.info("Build flavour: {} complete in directory: \"{}\"".format(args.flavour, os.path.abspath(path))) + logging.info("Build took {}".format(datetime.timedelta(seconds=int(time.time() - t0)))) + break + + if ret == 0: + windows_package(args) + else: + logging.info("Build failed") + sys.exit(1) def windows_package(args): @@ -233,7 +272,7 @@ def main(): parser.add_argument("--vcvars", help="vcvars batch file location, typically inside vs studio install dir", - default=KNOWN_VCVARS['VS 2015'], + default=KNOWN_VCVARS['VS 2019'], type=str) parser.add_argument("--arch", @@ -258,7 +297,7 @@ def main(): if 'OpenCV_DIR' not in os.environ: os.environ["OpenCV_DIR"] = "C:\\Program Files\\OpenCV-v3.4.1\\build" if 'CUDA_PATH' not in os.environ: - os.environ["CUDA_PATH"] = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2" + os.environ["CUDA_PATH"] = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" if 'MKL_ROOT' not in os.environ: os.environ["MKL_ROOT"] = "C:\\Program Files
(x86)\\IntelSWTools\\compilers_and_libraries\\windows\\mkl" windows_build(args) diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h index 6a367b3ccef5..783ad26803ed 100644 --- a/include/mxnet/imperative.h +++ b/include/mxnet/imperative.h @@ -117,6 +117,16 @@ class Imperative { } return is_np_shape_thread_local_ ? 1 : 0; } + + /*! \brief return current numpy default dtype compatibility status. + * */ + bool is_np_default_dtype() const { + if (is_np_default_dtype_global_) { + return true; + } + return false; + } + /*! \brief specify numpy compatibility off, thread local on or global on. */ bool set_is_np_shape(int is_np_shape) { NumpyShape flag = static_cast(is_np_shape); @@ -215,6 +225,7 @@ class Imperative { static MX_THREAD_LOCAL bool is_np_shape_thread_local_; #endif bool is_np_shape_global_{false}; + bool is_np_default_dtype_global_{false}; /*! \brief node count used for naming */ std::atomic node_count_{0}; /*! \brief variable count used for naming */ diff --git a/python/mxnet/gluon/nn/activations.py b/python/mxnet/gluon/nn/activations.py index 1b9ce91dd2aa..3cccc851e39b 100644 --- a/python/mxnet/gluon/nn/activations.py +++ b/python/mxnet/gluon/nn/activations.py @@ -139,7 +139,8 @@ def __init__(self, alpha_initializer=initializer.Constant(0.25), init=alpha_initializer) def hybrid_forward(self, F, x, alpha): - return F.LeakyReLU(x, gamma=alpha, act_type='prelu', name='fwd') + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, gamma=alpha, act_type='prelu', name='fwd') class ELU(HybridBlock): @@ -167,7 +168,8 @@ def __init__(self, alpha=1.0, **kwargs): self._alpha = alpha def hybrid_forward(self, F, x): - return F.LeakyReLU(x, act_type='elu', slope=self._alpha) + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, act_type='elu', slope=self._alpha) class SELU(HybridBlock): @@ -187,7 +189,9 @@ def __init__(self, **kwargs): super(SELU, self).__init__(**kwargs) def hybrid_forward(self, F, x): - return F.LeakyReLU(x, act_type='selu', name='fwd') + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, act_type='selu', name='fwd') + class GELU(HybridBlock): r""" @@ -206,7 +210,8 @@ def __init__(self, **kwargs): super(GELU, self).__init__(**kwargs) def hybrid_forward(self, F, x): - return F.LeakyReLU(x, act_type='gelu', name='fwd') + leaky_relu = F.npx.leaky_relu if is_np_array() else F.LeakyReLU + return leaky_relu(x, act_type='gelu', name='fwd') class Swish(HybridBlock): @@ -232,4 +237,7 @@ def __init__(self, beta=1.0, **kwargs): self._beta = beta def hybrid_forward(self, F, x): - return x * F.sigmoid(self._beta * x, name='fwd') + if is_np_array(): + return x * F.npx.sigmoid(self._beta * x) + else: + return x * F.sigmoid(self._beta * x, name='fwd') diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index fceaaf3a282f..9a803d48b5b3 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -6174,6 +6174,11 @@ def clip(a, a_min, a_max, out=None): >>> np.clip(a, 3, 6, out=a) array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32) """ + from numbers import Number + if isinstance(a, Number): + # In case input is a scalar, the computation would fall back to native numpy. + # The value returned would be a python scalar. 
+ return _np.clip(a, a_min, a_max, out=None) return _mx_nd_np.clip(a, a_min, a_max, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index da8d5d738f46..5f81a73a4b75 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -1529,13 +1529,15 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou if isinstance(rhs, numeric_types): return fn_scalar(lhs, rhs, out=out) else: + is_int = isinstance(rhs, integer_types) if rfn_scalar is None: # commutative function - return lfn_scalar(rhs, float(lhs), out=out) + return lfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out) else: - return rfn_scalar(rhs, float(lhs), out=out) + return rfn_scalar(rhs, scalar=float(lhs), is_int=is_int, out=out) elif isinstance(rhs, numeric_types): - return lfn_scalar(lhs, float(rhs), out=out) + is_int = isinstance(rhs, integer_types) + return lfn_scalar(lhs, scalar=float(rhs), is_int=is_int, out=out) elif isinstance(rhs, Symbol): return fn_array(lhs, rhs, out=out) else: diff --git a/src/common/utils.h b/src/common/utils.h index 44d4fc3e8772..227830708ff6 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -874,6 +875,11 @@ inline bool is_float(const int dtype) { return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16; } +inline bool is_int(const int dtype) { + return dtype == mshadow::kUint8 || dtype == mshadow::kInt8 || + dtype == mshadow::kInt32 || dtype == mshadow::kInt64; +} + inline int get_more_precise_type(const int type1, const int type2) { if (type1 == type2) return type1; if (is_float(type1) && is_float(type2)) { @@ -910,6 +916,19 @@ inline int np_binary_out_infer_type(const int type1, const int type2) { return get_more_precise_type(type1, type2); } +inline int GetDefaultDtype() { + return Imperative::Get()->is_np_default_dtype() ? + mshadow::kFloat64 : + mshadow::kFloat32; +} + +inline int GetDefaultDtype(int dtype) { + if (dtype != -1) return dtype; + return Imperative::Get()->is_np_default_dtype() ? + mshadow::kFloat64 : + mshadow::kFloat32; +} + } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/operator/contrib/gradient_multiplier_op.cc b/src/operator/contrib/gradient_multiplier_op.cc index 0a49ec1c36b3..5221d89a6056 100644 --- a/src/operator/contrib/gradient_multiplier_op.cc +++ b/src/operator/contrib/gradient_multiplier_op.cc @@ -77,9 +77,7 @@ In forward pass it acts as an identity transform. During backpropagation it multiplies the gradient from the subsequent level by a scalar factor lambda and passes it to the preceding layer. )code" ADD_FILELINE) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferStorageType", ElemwiseStorageType<1, 1, false, true, true>) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FComputeEx", UnaryOp::IdentityComputeEx) @@ -88,7 +86,7 @@ the preceding layer. 
[](const NodeAttrs& attrs){ return std::vector{true}; }) -.add_argument("scalar", "float", "lambda multiplier"); +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_contrib_backward_gradientmultiplier) .set_attr("TIsBackward", true) diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 2d4d49254676..61d299d39a40 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -148,7 +148,6 @@ struct true_divide : public mxnet_op::tunable { return static_cast(a) / static_cast(b); } -#ifndef _WIN32 template::value, int>::type = 0> MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { @@ -166,7 +165,6 @@ struct true_divide : public mxnet_op::tunable { MSHADOW_XINLINE static double Map(DType a, double b) { return static_cast(a) / b; } -#endif }; struct rtrue_divide : public mxnet_op::tunable { @@ -182,7 +180,6 @@ struct rtrue_divide : public mxnet_op::tunable { return static_cast(b) / static_cast(a); } -#ifndef _WIN32 template::value, int>::type = 0> MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { @@ -200,14 +197,12 @@ struct rtrue_divide : public mxnet_op::tunable { MSHADOW_XINLINE static double Map(DType a, double b) { return b / static_cast(a); } -#endif }; MXNET_BINARY_MATH_OP_NC(left, a); MXNET_BINARY_MATH_OP_NC(right, b); -#ifndef _WIN32 struct mixed_plus { template::value, int>::type = 0> @@ -345,8 +340,12 @@ struct mixed_rpower { return static_cast(math::pow(b, a)); } }; -#endif +#pragma GCC diagnostic push +#if __GNUC__ >= 7 +#pragma GCC diagnostic ignored "-Wint-in-bool-context" +#pragma GCC diagnostic ignored "-Wbool-compare" +#endif MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b); MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b); @@ -575,7 +574,6 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a)); MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b)); MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b)); - MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b))); MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * b))); @@ -728,6 +726,10 @@ MXNET_BINARY_MATH_OP_NC(minus_sign, a - b > DType(0) ? 
DType(1) : -DType(1)); MXNET_BINARY_MATH_OP(rminus, b - a); +MXNET_BINARY_MATH_OP_NC(posone, 1); + +MXNET_BINARY_MATH_OP_NC(negone, -1); + MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b)); template<> @@ -795,6 +797,73 @@ struct mod : public mxnet_op::tunable { } }; +struct mixed_mod { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return mod::Map(static_cast(a), b); + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return mod::Map(static_cast(a), b); + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return mod::Map(static_cast(a), b); + } +}; + +struct mixed_rmod { + template::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return mod::Map(b, static_cast(a)); + } + + template::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return mod::Map(b, static_cast(a)); + } + + template::value || + std::is_same::value || + std::is_integral::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return mod::Map(b, static_cast(a)); + } +}; + +struct fmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (b == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(a), static_cast(b))); + } + } +}; + +struct rfmod : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (a == DType(0)) { + return DType(0); + } else { + return DType(::fmod(static_cast(b), static_cast(a))); + } + } +}; template<> MSHADOW_XINLINE mshadow::half::half2_t mod::Map diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index f7dbce270c87..0c2092f55d41 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -849,7 +849,13 @@ struct op_with_req { KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value)); } -#ifndef _WIN32 + /*! \brief input is two tensors with different type and with a boolean output tensor */ + template::value, int>::type = 0> + MSHADOW_XINLINE static void Map(index_t i, bool *out, const LType *lhs, const RType *rhs) { + KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); + } + /*! \brief inputs are two tensors with a half_t output tensor */ template::value, int>::type = 0> @@ -903,7 +909,6 @@ struct op_with_req { MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); } -#endif /*! 
\brief inputs are two tensors with a float output tensor */ template& inputs, int j = 0; for (idim = 0; idim < ndim_iter; ++idim) { if (op_axes_arrays[i][idim] == -1 || - opshape[i][op_axes_arrays[i][idim]] == 1) { + (iop != nop && opshape[i][op_axes_arrays[i][idim]] == 1 && + op_axes_arrays[iop][idim] != -1 && + opshape[iop][op_axes_arrays[iop][idim]] != 1)) { remainstride[iop][j++] = iterstride[iop][idim]; } else { opstride[iop][op_axes_arrays[i][idim]] = iterstride[iop][idim]; diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc index 74db52d33f03..a54c7cac51a3 100644 --- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc @@ -206,7 +206,8 @@ struct TVMBinaryBroadcastScalarCompute { // scalar param type_codes[1] = kDLFloat; - values[1].v_float64 = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + values[1].v_float64 = param.scalar; // output tensor type_codes[2] = kTVMDLTensorHandle; @@ -225,9 +226,7 @@ struct TVMBinaryBroadcastScalarCompute { NNVM_REGISTER_OP(_npi_##name##_scalar) \ .set_num_inputs(1) \ .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ + .set_attr_parser(ParamParser) \ .set_attr("FListInputNames", \ [](const NodeAttrs& attrs) { \ return std::vector{"data"}; \ @@ -240,7 +239,7 @@ struct TVMBinaryBroadcastScalarCompute { }) \ .set_attr("FGradient", MakeZeroGradNodes) \ .add_argument("data", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("scalar", "float", "scalar input") + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(equal); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC(not_equal); @@ -285,9 +284,12 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_GPU(less_equal); #else -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ - NNVM_REGISTER_OP(_npi_##name##_scalar) \ - .set_attr("FCompute", BinaryScalarOp::ComputeLogic) +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR_LOGIC_CPU(name) \ + NNVM_REGISTER_OP(_npi_##name##_scalar) \ + .set_attr("FCompute", BinaryScalarOp::ComputeLogic) \ + .set_attr("FResourceRequest", [](const NodeAttrs& n) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) #endif // MXNET_USE_TVM_OP diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index ae285caa9094..ae6697c0b23e 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,27 +23,26 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. 
*/ -#include #include "./np_elemwise_broadcast_op.h" namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +DMLC_REGISTER_PARAMETER(NumpyBinaryScalarParam); + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, std::vector* in_attrs, @@ -61,7 +60,6 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, return true; } -#ifndef _WIN32 #define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(2) \ @@ -76,68 +74,62 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ - .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ - .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") -#else -#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(2) \ - .set_num_outputs(1) \ - .set_attr("FListInputNames", \ + .set_attr("FResourceRequest", \ [](const NodeAttrs& attrs) { \ - return std::vector{"lhs", "rhs"}; \ - }) \ - .set_attr("FInferShape", BinaryBroadcastShape) \ - .set_attr("FInferType", NumpyBinaryMixedPrecisionType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}, {1, 0}}; \ + return std::vector{ResourceRequest::kTempSpace}; \ }) \ - .set_attr("FResourceRequest", \ - [](const NodeAttrs& attrs) { \ - return std::vector{ResourceRequest::kTempSpace}; \ - }) \ .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") -#endif MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"}); + +NNVM_REGISTER_OP(_backward_npi_broadcast_add) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract) -#ifndef _WIN32 .set_attr( "FCompute", 
NumpyBinaryBroadcastCompute) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute) -#endif -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"}); + +NNVM_REGISTER_OP(_backward_npi_broadcast_sub) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"}); NNVM_REGISTER_OP(_backward_npi_broadcast_mul) @@ -155,21 +147,33 @@ NNVM_REGISTER_OP(_backward_npi_broadcast_mul) .set_attr("FCompute", NumpyBinaryBackwardUseIn); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}); +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_mod) +.set_attr( + "FCompute", + NumpyBinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mod"}); + +NNVM_REGISTER_OP(_backward_npi_broadcast_mod) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_power) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool) -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool) -#endif .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_power"}); NNVM_REGISTER_OP(_backward_npi_broadcast_power) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index 1e0130494469..a2927cda61ff 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -29,59 +29,50 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_add) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif + +NNVM_REGISTER_OP(_backward_npi_broadcast_add) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_subtract) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastCompute); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastCompute); -#endif + +NNVM_REGISTER_OP(_backward_npi_broadcast_sub) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_multiply) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_mul) .set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_mod) -.set_attr("FCompute", BinaryBroadcastCompute); +.set_attr( + "FCompute", + NumpyBinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_broadcast_mod) +.set_attr("FCompute", 
NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_power) -#ifndef _WIN32 .set_attr( "FCompute", NumpyBinaryBroadcastComputeWithBool); -#else -.set_attr( - "FCompute", - NumpyBinaryBroadcastComputeWithBool); -#endif NNVM_REGISTER_OP(_backward_npi_broadcast_power) .set_attr("FCompute", NumpyBinaryBackwardUseIn* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return in_attrs->at(0) != -1; -} - inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) { LOG(FATAL) << "Operator " << op_name << " does not support combination of " << mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2) << " yet..."; } -#ifndef _WIN32 template void MixedAllRealBinaryElemwiseCompute(const std::string& op_name, const OpContext& ctx, @@ -153,7 +142,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, const TBlob& lhs = inputs[0]; const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { if (lhs.type_flag_ == out.type_flag_) { MixedAllRealBinaryElemwiseCompute(attrs.op->name, ctx, lhs, rhs, out, req[0]); @@ -227,13 +215,9 @@ void MixedAllRealBinaryBroadcastCompute(const std::string& op_name, } }); } -#endif -#ifndef _WIN32 + template -#else -template -#endif void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -248,11 +232,9 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; -#ifndef _WIN32 mxnet::TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); - if (!ndim) { MixedBinaryElemwiseCompute(attrs, ctx, inputs, req, outputs); } else { @@ -290,47 +272,34 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, }); } }); + } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } } else { PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } } -#else - mshadow::Stream *s = ctx.get_stream(); - if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - TBlob temp_tblob; - // one is float, the other is bool - CHECK((out.type_flag_ == lhs.type_flag_) || (out.type_flag_ == rhs.type_flag_)) - << "This case out type should be same as the float type"; - if (lhs.type_flag_ == out.type_flag_) { - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, 
{rhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - BinaryBroadcastCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); - } -#endif } -#ifndef _WIN32 template -#else -template -#endif void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -352,18 +321,10 @@ void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, return; } -#ifndef _WIN32 MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#else - MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#endif } -#ifndef _WIN32 template -#else -template -#endif void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -384,12 +345,31 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, BinaryBroadcastComputeWithBool(attrs, ctx, inputs, req, outputs); return; } - -#ifndef _WIN32 + if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { + Stream *s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + return; + } MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#else - MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); -#endif } template diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 52d681885d82..60b721e4854d 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -30,21 +30,19 @@ namespace mxnet { namespace op { -#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ - NNVM_REGISTER_OP(name) \ - .set_num_inputs(1) \ - .set_num_outputs(1) \ - .set_attr_parser([](NodeAttrs* attrs) { \ - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); \ - }) \ - .set_attr("FInferShape", ElemwiseShape<1, 1>) \ - .set_attr("FInferType", NumpyBinaryScalarType) \ - .set_attr("FInplaceOption", \ - [](const NodeAttrs& attrs){ \ - return std::vector >{{0, 0}}; \ - }) \ - .add_argument("data", "NDArray-or-Symbol", "source input") \ - .add_argument("scalar", "float", "scalar input") +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser(ParamParser) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FResourceRequest", \ + 
[](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign) .describe(R"code()code" ADD_FILELINE) @@ -87,9 +85,7 @@ NNVM_REGISTER_OP(_npi_lcm) NNVM_REGISTER_OP(_npi_lcm_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -98,7 +94,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_npi_bitwise_and) @@ -122,9 +118,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and) NNVM_REGISTER_OP(_npi_bitwise_and_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -133,7 +127,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); NNVM_REGISTER_OP(_npi_bitwise_xor) @@ -175,9 +169,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or) NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -186,15 +178,13 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); NNVM_REGISTER_OP(_npi_bitwise_or_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", ElemwiseIntType<1, 1>) .set_attr("FInplaceOption", @@ -203,7 +193,7 @@ NNVM_REGISTER_OP(_npi_bitwise_or_scalar) }) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) .set_attr("FCompute", BinaryScalarOp::ComputeInt); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) @@ -275,14 +265,14 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"}); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) 
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); @@ -363,13 +353,13 @@ NNVM_REGISTER_OP(_backward_npi_ldexp) mshadow_op::ldexp_rgrad>); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); }) +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr_parser(ParamParser) .set_attr("FCompute", BinaryScalarOp::Backward); } // namespace op diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 593a698dc93a..a64fb4db12b6 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -917,7 +917,7 @@ void NumpyConcatenateForward(const nnvm::NodeAttrs& attrs, ConcatParam cparam; cparam.num_args = param.num_args; cparam.dim = param.axis.has_value() ? param.axis.value() : 0; - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { ConcatOp op; op.Init(cparam); op.Forward(ctx, data, req, outputs); @@ -950,7 +950,7 @@ void NumpyConcatenateBackward(const nnvm::NodeAttrs& attrs, ConcatParam cparam; cparam.num_args = param.num_args; cparam.dim = param.axis.has_value() ? 
param.axis.value() : 0; - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { ConcatOp op; op.Init(cparam); op.Backward(ctx, inputs[0], req, data); diff --git a/src/operator/numpy/np_true_divide-inl.h b/src/operator/numpy/np_true_divide-inl.h index e7a1c193d97f..d2dbc82b3e9b 100644 --- a/src/operator/numpy/np_true_divide-inl.h +++ b/src/operator/numpy/np_true_divide-inl.h @@ -29,6 +29,7 @@ #include #include "../../common/utils.h" #include "../tensor/elemwise_binary_broadcast_op.h" +#include "../numpy/np_elemwise_broadcast_op.h" namespace mxnet { namespace op { @@ -46,7 +47,8 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, using namespace mxnet_op; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; if (out.type_flag_ == data.type_flag_) { @@ -57,8 +59,7 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, }); }); } else { -#ifndef _WIN32 - CHECK(out.type_flag_ == mshadow::kFloat32 || out.type_flag_ == mshadow::kFloat64) + CHECK_EQ(out.type_flag_, mxnet::common::GetDefaultDtype()) << "true_divide only supports float32 and float64" " output when input's dtype is " << type_string(inputs[0].type_flag_); @@ -71,13 +72,6 @@ void TrueDivideScalarCompute(const nnvm::NodeAttrs &attrs, }); }); }); -#else - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(data.Size()), s); - TBlob temp_tblob(temp_tensor); - CastCompute(attrs, ctx, {data}, {kWriteTo}, {temp_tblob}); - TrueDivideScalarCompute(attrs, ctx, {temp_tblob}, req, outputs); -#endif } } @@ -119,12 +113,10 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, }); } } else { -#ifndef _WIN32 - // Non-windows case: no usage of temporary space // Case when types of the 2 input tensors are different if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { // both lhs and rhs are float types, output type is the more precise one - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { // one is float type, the other is integer type, the output type should be the same as float CHECK_EQ(out.type_flag_, @@ -153,46 +145,8 @@ void TrueDivideElemwiseCompute(const nnvm::NodeAttrs &attrs, } } else { // lhs is integer type, rhs is integer type, output type should be float - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } -#else - // Windows case: using temp space for casting the type - // Case when types of the 2 input tensors are different - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { - // both lhs and rhs are float types, output type is the more precise one - LOG(ERROR) << "not implemented yet..."; - } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - // lhs is float type, rhs is integer type, the output type should be the same as lhs - CHECK_EQ(out.type_flag_, - common::is_float(lhs.type_flag_) ? 
lhs.type_flag_ : rhs.type_flag_) - << "This case out type should be same as the float type"; - TBlob temp_tblob; - if (common::is_float(lhs.type_flag_)) { - // lhs is the float one - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - TrueDivideElemwiseCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - // rhs is the float one - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - TrueDivideElemwiseCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - // lhs is integer type, rhs is integer type, output type should be float - LOG(ERROR) << "not implemented yet..."; - } -#endif } } @@ -216,7 +170,6 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, const TBlob& lhs = inputs[0]; const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; -#ifndef _WIN32 BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape oshape = new_oshape.get(); mshadow::Shape lstride = calc_stride(new_lshape.get()); @@ -244,7 +197,7 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, } else { if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { // lhs and rhs have different float types, the output is the more precise one - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { // one of lhs and rhs is float, the output is the same type as the float one if (common::is_float(lhs.type_flag_)) { @@ -272,74 +225,10 @@ void TrueDivideBroadcastCompute(const nnvm::NodeAttrs& attrs, } } else { // lhs and rhs have different integer types, the output is float type - LOG(ERROR) << "not implemented yet..."; + LOG(FATAL) << "not implemented yet..."; } } }); -#else - if (lhs.type_flag_ == rhs.type_flag_) { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = calc_stride(new_lshape.get()); - mshadow::Shape rstride = calc_stride(new_rshape.get()); - // When the both inputs have the same data types - if (common::is_float(lhs.type_flag_)) { - // If both inputs are the same float types, output is the same float type - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, DType, { - Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - lhs.dptr(), rhs.dptr(), out.dptr()); - }); - } else { - CHECK_EQ(out.type_flag_, mshadow::kFloat32) - << "true_divide only supports float32 output when input's dtype is " - << type_string(lhs.type_flag_); - MXNET_INT_TYPE_SWITCH(lhs.type_flag_, DType, { - // If both inputs are the same integer types, output is float type - Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - lhs.dptr(), rhs.dptr(), out.dptr()); - }); - } - }); - } else { - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { - // lhs and rhs have different float types, the output is the more precise one - LOG(ERROR) << "not implemented yet..."; - } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { - // one of lhs and rhs is float, the output is the same type as 
the float one - TBlob temp_tblob; - if (common::is_float(lhs.type_flag_)) { - // lhs is float type, output will be the same float type - CHECK_EQ(lhs.type_flag_, out.type_flag_) - << "lhs should have the same type as out, infer type broken?"; - MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(rhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); - TrueDivideBroadcastCompute( - attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); - } else { - // rhs is float type, output will be the same float type - CHECK_EQ(rhs.type_flag_, out.type_flag_) - << "rhs should have the same type as out, infer type broken?"; - MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { - Tensor temp_tensor = - ctx.requested[0].get_space_typed(mshadow::Shape1(lhs.Size()), s); - temp_tblob = TBlob(temp_tensor); - }); - CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); - TrueDivideBroadcastCompute( - attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); - } - } else { - // lhs and rhs have different integer types, the output is float type - LOG(ERROR) << "not implemented yet..."; - } - } -#endif } } diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 6edfb4dd0901..2f3bf7d5dfb6 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -74,46 +74,46 @@ NNVM_REGISTER_OP(_npi_true_divide) [](const NodeAttrs& attrs){ return std::vector >{{0, 0}, {1, 0}}; }) -#ifdef _WIN32 +.set_attr("FCompute", TrueDivideBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"}) +.add_argument("lhs", "NDArray-or-Symbol", "Dividend array") +.add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); + + +NNVM_REGISTER_OP(_backward_npi_broadcast_div) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -#endif -.set_attr("FCompute", TrueDivideBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"}) -.add_argument("lhs", "NDArray-or-Symbol", "Dividend array") -.add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); +.set_attr("FCompute", NumpyBinaryBackwardUseIn); NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 1>) .set_attr("FInferType", TrueDivideType<1>) .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) -#ifdef _WIN32 -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -#endif .set_attr("FCompute", TrueDivideScalarCompute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "float", "scalar input"); +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()); NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = dmlc::stod(attrs->dict["scalar"]); - }) +.set_attr_parser(ParamParser) .set_attr("FInferShape", ElemwiseShape<1, 
1>) .set_attr("FInferType", TrueDivideType<1>) .set_attr("FInplaceOption", @@ -129,7 +129,7 @@ NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) .set_attr("FCompute", TrueDivideScalarCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_rdiv_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "float", "scalar input"); +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu index 7211f4a0a006..c8eccfe140b4 100644 --- a/src/operator/numpy/np_true_divide.cu +++ b/src/operator/numpy/np_true_divide.cu @@ -31,6 +31,10 @@ namespace op { NNVM_REGISTER_OP(_npi_true_divide) .set_attr("FCompute", TrueDivideBroadcastCompute); +NNVM_REGISTER_OP(_backward_npi_broadcast_div) +.set_attr("FCompute", NumpyBinaryBackwardUseIn); + NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_attr("FCompute", TrueDivideScalarCompute); diff --git a/src/operator/numpy/np_unique_op.cc b/src/operator/numpy/np_unique_op.cc index 2f57733a72b2..7a299cdd5221 100644 --- a/src/operator/numpy/np_unique_op.cc +++ b/src/operator/numpy/np_unique_op.cc @@ -375,6 +375,7 @@ NNVM_REGISTER_OP(_npi_unique) [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "The input array") .add_arguments(NumpyUniqueParam::__FIELDS__()); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 0cc0dc92f884..de6efbfe5758 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -421,6 +421,8 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::posone); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::negone); // NOLINT() /*! * \brief Tuner objects, *not* automatically generated */ diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index ffd0f123070a..774c87afaf3c 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -227,8 +227,8 @@ struct binary_broadcast_kernel { } } -#ifndef _WIN32 /*! \brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ template::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req, @@ -249,6 +249,7 @@ struct binary_broadcast_kernel { } /*! 
\brief Map function for binary_broadcast_kernel */ + /* used for mixed type binary ops */ template::value && !std::is_pointer::value, int>::type = 0> @@ -268,7 +269,6 @@ struct binary_broadcast_kernel { KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs, rhs[ridx])); } } -#endif }; template diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index 4eaaff09d83d..7d65c89f7b0f 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "elemwise_unary_op.h" @@ -36,6 +37,45 @@ namespace mxnet { namespace op { +struct NumpyBinaryScalarParam : public dmlc::Parameter { + double scalar; + bool is_int; + DMLC_DECLARE_PARAMETER(NumpyBinaryScalarParam) { + DMLC_DECLARE_FIELD(scalar) + .set_default(1) + .describe("Scalar input value"); + DMLC_DECLARE_FIELD(is_int) + .set_default(true) + .describe("Indicate whether scalar input is int type"); + } + + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream scalar_s, is_int_s; + scalar_s << scalar; + is_int_s << is_int; + (*dict)["scalar"] = scalar_s.str(); + (*dict)["is_int"] = is_int_s.str(); + } +}; + +inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + if (common::is_int(in_attrs->at(0)) && !scalar_is_int) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64); + } else if (in_attrs->at(0) == mshadow::kBool) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, scalar_is_int ? mshadow::kInt64 : mshadow::kFloat64); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } + return out_attrs->at(0) != -1; +} + class BinaryScalarOp : public UnaryOp { /*! 
\brief Tensor operation against a scalar with a dense result */ template @@ -45,7 +85,8 @@ class BinaryScalarOp : public UnaryOp { const NDArray &input, const OpReqType req, const NDArray &output) { - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; CHECK_EQ(output.shape(), input.shape()); const int64_t row_count = output.shape()[0]; const int64_t items_per_row = output.shape().Size() / row_count; @@ -137,7 +178,8 @@ class BinaryScalarOp : public UnaryOp { const NDArray &output) { CHECK_EQ(output.shape(), input.shape()); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; const DType dense_fill_val = OP::Map(DType(0), DType(alpha)); const TBlob column_indexes = input.aux_data(csr::kIdx); const size_t item_count = column_indexes.Size(); @@ -236,11 +278,23 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + TBlob temp_tblob; + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + const double alpha = param.scalar; MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + if ((common::is_int(inputs[0].type_flag_) && !scalar_is_int) || + (inputs[0].type_flag_ == kBool)) { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); + temp_tblob = TBlob(temp_tensor); + CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); + } else { + temp_tblob = inputs[0]; + } MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + s, inputs[0].Size(), outputs[0].dptr(), temp_tblob.dptr(), DType(alpha)); }); }); } @@ -256,7 +310,8 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + const double alpha = param.scalar; MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet_op::Kernel, xpu>::Launch( @@ -276,12 +331,23 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - const double alpha = nnvm::get(attrs.parsed); - MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - mxnet_op::Kernel, xpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); - }); + const NumpyBinaryScalarParam& param = nnvm::get(attrs.parsed); + bool scalar_is_int = param.is_int; + const double alpha = param.scalar; + TBlob temp_tblob; + if (common::is_int(inputs[0].type_flag_) && !scalar_is_int) { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(inputs[0].Size()), s); + temp_tblob = TBlob(temp_tensor); + CastCompute(attrs, ctx, {inputs[0]}, {kWriteTo}, {temp_tblob}); + } else { + temp_tblob = inputs[0]; + } + MSHADOW_TYPE_SWITCH_WITH_BOOL(temp_tblob.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, inputs[0].Size(), outputs[0].dptr(), temp_tblob.dptr(), DType(alpha)); + }); }); } @@ -345,7 +411,8 @@ class BinaryScalarOp : public UnaryOp { using namespace mshadow; using namespace mshadow::expr; Stream 
-    const double alpha = nnvm::get<double>(attrs.parsed);
+    const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+    const double alpha = param.scalar;
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet::op::mxnet_op::Kernel<mxnet::op::mxnet_op::backward_grad_tuned<OP>, xpu>::Launch(
-#define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name)                    \
-  NNVM_REGISTER_OP(name)                                               \
-  .set_num_inputs(1)                                                   \
-  .set_num_outputs(1)                                                  \
-  .set_attr_parser([](NodeAttrs* attrs) {                              \
-      attrs->parsed = dmlc::stod(attrs->dict["scalar"]);               \
-    })                                                                 \
-  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)     \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)        \
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                    \
-    [](const NodeAttrs& attrs){                                        \
-      return std::vector<std::pair<int, int> >{{0, 0}};                \
-    })                                                                 \
-  .add_argument("data", "NDArray-or-Symbol", "source input")           \
-  .add_argument("scalar", "float", "scalar input")
+#define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name)                    \
+  NNVM_REGISTER_OP(name)                                               \
+  .set_num_inputs(1)                                                   \
+  .set_num_outputs(1)                                                  \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)                \
+  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)     \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType)     \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                    \
+    [](const NodeAttrs& attrs){                                        \
+      return std::vector<std::pair<int, int> >{{0, 0}};                \
+    })                                                                 \
+  .set_attr<FResourceRequest>("FResourceRequest",                      \
+    [](const NodeAttrs& attrs) {                                       \
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
+    })                                                                 \
+  .add_argument("data", "NDArray-or-Symbol", "source input")           \
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 } // namespace op
 } // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
index 13014b35c2fe..dc11593255e7 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc
@@ -28,22 +28,24 @@
 #include "./elemwise_binary_scalar_op.h"
 
 #define MXNET_OPERATOR_REGISTER_BINARY_WITH_SCALAR_SUPPORT_WITH_DENSE_RESULT(name) \
-  NNVM_REGISTER_OP(name)                                              \
-  .set_num_inputs(1)                                                  \
-  .set_num_outputs(1)                                                 \
-  .set_attr_parser([](NodeAttrs* attrs) {                             \
-      attrs->parsed = dmlc::stod(attrs->dict["scalar"]);              \
-    })                                                                \
-  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)    \
-  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)       \
-  .set_attr<FInferStorageType>("FInferStorageType",                   \
-    BinaryScalarStorageTypeWithDenseResultStorageType)                \
-  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                   \
-    [](const NodeAttrs& attrs){                                       \
-      return std::vector<std::pair<int, int> >{{0, 0}};               \
-    })                                                                \
-  .add_argument("data", "NDArray-or-Symbol", "source input")          \
-  .add_argument("scalar", "float", "scalar input")
+  NNVM_REGISTER_OP(name)                                              \
+  .set_num_inputs(1)                                                  \
+  .set_num_outputs(1)                                                 \
+  .set_attr_parser(ParamParser<NumpyBinaryScalarParam>)               \
+  .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)    \
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryScalarType)    \
+  .set_attr<FInferStorageType>("FInferStorageType",                   \
+    BinaryScalarStorageTypeWithDenseResultStorageType)                \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                   \
+    [](const NodeAttrs& attrs){                                       \
+      return std::vector<std::pair<int, int> >{{0, 0}};               \
+    })                                                                \
+  .set_attr<FResourceRequest>("FResourceRequest",                     \
+    [](const NodeAttrs& attrs) {                                      \
+      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \
+    })                                                                \
+  .add_argument("data", "NDArray-or-Symbol", "source input")          \
+  .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
 
 namespace mxnet {
 namespace op {
@@ -65,7 +67,8 @@ static bool BinaryScalarStorageTypeWithDenseResultStorageType(const NodeAttrs& attrs,
   const NDArrayStorageType instype = static_cast<NDArrayStorageType>(in_attrs->at(0));
   const auto dispatch_ex = invalid_ctx ?
                             DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx;
-  const double alpha = nnvm::get<double>(attrs.parsed);
+  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+  const double alpha = param.scalar;
   if (common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
     dispatched = storage_type_assign(&out_attrs[0], kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
@@ -189,8 +192,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rdiv_scalar)
 .add_alias("_RDivScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rdiv_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<
   cpu, mshadow_op::rdiv_grad>);
 
@@ -200,8 +203,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_mod_scalar)
 .add_alias("_ModScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_mod_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<
   cpu, mshadow_op::mod_grad>);
 
@@ -211,8 +214,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rmod_scalar)
 .add_alias("_RModScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rmod_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<
   cpu, mshadow_op::rmod_grad>);
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
index 7dd8cf41c59e..c08cdcdf0dc8 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc
@@ -36,8 +36,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_maximum_scalar)
 .add_alias("_npi_maximum_scalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_maximum_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<cpu, mshadow_op::ge>);
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar)
@@ -47,8 +47,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_minimum_scalar)
 .add_alias("_npi_minimum_scalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_minimum_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<cpu, mshadow_op::le>);
 
 MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar)
@@ -57,8 +57,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_power_scalar)
 .add_alias("_PowerScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_power_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<
   cpu, mshadow_op::power_grad>);
 
@@ -69,8 +69,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rpower_scalar)
 .add_alias("_RPowerScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rpower_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<
   cpu, mshadow_op::rpower_grad>);
 
@@ -82,8 +82,8 @@ MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_hypot_scalar)
 .add_alias("_HypotScalar");
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_hypot_scalar)
-.add_argument("scalar", "float", "scalar value")
-.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = dmlc::stod(attrs->dict["scalar"]); })
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<
   cpu, mshadow_op::hypot_grad_left>);
 
@@ -109,13 +109,7 @@ Example::
 )code" ADD_FILELINE)
 .set_num_inputs(1)
 .set_num_outputs(1)
-.set_attr_parser([](NodeAttrs* attrs) {
-    if (attrs->dict.find("scalar") != attrs->dict.end()) {
-      attrs->parsed = dmlc::stod(attrs->dict["scalar"]);
-    } else {
-      attrs->parsed = 1.0;
-    }
-  })
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
@@ -128,13 +122,7 @@ Example::
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_smooth_l1" });
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_smooth_l1)
-  .set_attr_parser([](NodeAttrs *attrs) {
-    if (attrs->dict.find("scalar") != attrs->dict.end()) {
-      attrs->parsed = dmlc::stod(attrs->dict["scalar"]);
-    } else {
-      attrs->parsed = 1.0;
-    }
-})
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<cpu, mshadow_op::smooth_l1_gradient>);
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
index 17e76153ebb2..0594997877d3 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
+++ b/src/operator/tensor/elemwise_binary_scalar_op_logic.cc
@@ -46,7 +46,8 @@ static bool BinaryScalarLogicStorageType(const nnvm::NodeAttrs& attrs,
   const auto in_stype = in_attrs->at(0);
   auto &out_stype = out_attrs->at(0);
   bool dispatched = false;
-  const double alpha = nnvm::get<double>(attrs.parsed);
+  const NumpyBinaryScalarParam& param = nnvm::get<NumpyBinaryScalarParam>(attrs.parsed);
+  const double alpha = param.scalar;
   bool is_sparse = OP::Map(static_cast<double>(0), alpha) == 0;
   if (!dispatched && in_stype == kDefaultStorage) {
     // dns -> dns
diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py
index 204fde6bd2bd..836bcd1fc11e 100644
--- a/tests/python/unittest/test_numpy_gluon.py
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -25,6 +25,7 @@
 import mxnet as mx
 from mxnet import gluon, autograd, np
 from mxnet.test_utils import use_np, assert_almost_equal, check_gluon_hybridize_consistency
+from mxnet.gluon import nn
 from common import with_seed
 import random
 
@@ -422,6 +423,55 @@ def hybrid_forward(self, F, valid_length):
     assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy())
 
 
+@with_seed()
+@use_np
+def test_activations_leakyrelu():
+    # Currently, the activation tests only check that the layer is runnable.
+    act_layer = nn.LeakyReLU(0.1)
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_prelu():
+    act_layer = nn.PReLU()
+    act_layer.initialize()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_elu():
+    act_layer = nn.ELU(1.0)
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_selu():
+    act_layer = nn.SELU()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_gelu():
+    act_layer = nn.GELU()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
+
+@with_seed()
+@use_np
+def test_activations_swish():
+    act_layer = nn.Swish()
+    out = act_layer(mx.np.random.uniform(size=(10,)))
+    out.asnumpy()
+
 
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 32c8521afd51..20da12b12f48 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2548,8 +2548,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
             assert y.shape == np_out.shape
             assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol,
                                 use_broadcast=False, equal_nan=True)
-            if lgrad:
+            if (ltype in itypes) and (rtype in itypes):
+                continue
             y.backward()
             if ltype not in itypes:
                 assert_almost_equal(mx_test_x1.grad.asnumpy(),
@@ -2573,10 +2574,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                                 use_broadcast=False, equal_nan=True)
 
     funcs = {
-        'add': (-1.0, 1.0, None, None),
-        'subtract': (-1.0, 1.0, None, None),
+        'add': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
+                lambda y, x1, x2: _np.ones(y.shape)),
+        'subtract': (-1.0, 1.0, lambda y, x1, x2: _np.ones(y.shape),
+                     lambda y, x1, x2: _np.ones(y.shape) * -1),
         'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape),
                      lambda y, x1, x2: _np.broadcast_to(x1, y.shape)),
+        'mod': (1.0, 5.0, None, None),
         'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                   lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
     }
@@ -2590,8 +2594,6 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                    ((2, 3), ()),
                    ((), (2, 3))]
 
-    itypes = [np.bool, np.int8, np.int32, np.int64]
-    ftypes = [np.float16, np.float32, np.float64]
     for func, func_data in funcs.items():
         low, high, lgrad, rgrad = func_data
         for lshape, rshape in shape_pairs:
@@ -2604,6 +2606,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
                     continue
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
+            if func == 'subtract' or func == 'mod':
+                continue
+            for type1, type2 in itertools.product(itypes, itypes):
+                if type1 == type2:
+                    continue
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
+
 
 @with_seed()
 @use_np
@@ -3145,52 +3154,56 @@ def get_new_shape(shape, axis):
             shape_lst[axis] = random.randint(0, 3)
         return tuple(shape_lst)
 
-    for shape in [(0, 0), (2, 3), (2, 1, 3)]:
-        for hybridize in [True, False]:
-            for axis in [0, 1, None]:
-                for grad_req in ['write', 'add', 'null']:
-                    # test gluon
-                    test_concat = TestConcat(axis=axis)
-                    if hybridize:
-                        test_concat.hybridize()
+    shapes = [(0, 0), (2, 3), (2, 1, 3)]
+    hybridizes = [True, False]
+    axes = [0, 1, None]
+    grad_reqs = ['write', 'add', 'null']
+    dtypes = [np.float32, np.float64, np.bool]
+    combinations = itertools.product(shapes, hybridizes, axes, grad_reqs, dtypes)
 
-                    grad_req_c = grad_req
-                    grad_req_d = grad_req
-                    if grad_req == 'null':
-                        ide = random.randint(0, 2)
-                        grad_req_c = 'write' if ide == 0 else 'add'
-                        grad_req_c = 'write' if ide == 1 else 'add'
+    for shape, hybridize, axis, grad_req, dtype in combinations:
+        # test gluon
+        test_concat = TestConcat(axis=axis)
+        if hybridize:
+            test_concat.hybridize()
 
-                    a = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    a.attach_grad(grad_req)
-                    b = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    b.attach_grad(grad_req)
-                    c = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    c.attach_grad(grad_req_c)
-                    d = mx.nd.random.uniform(-1.0, 1.0, shape=get_new_shape(shape, axis)).as_np_ndarray()
-                    d.attach_grad(grad_req_d)
-                    expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+        grad_req_c = grad_req
+        grad_req_d = grad_req
+        if grad_req == 'null':
+            ide = random.randint(0, 2)
+            grad_req_c = 'write' if ide == 0 else 'add'
+            grad_req_c = 'write' if ide == 1 else 'add'
 
-                    with mx.autograd.record():
-                        y = test_concat(a, b, c, d)
+        a = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        a.attach_grad(grad_req)
+        b = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        b.attach_grad(grad_req)
+        c = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        c.attach_grad(grad_req_c)
+        d = np.random.uniform(-1.0, 1.0, size=get_new_shape(shape, axis)).astype(dtype)
+        d.attach_grad(grad_req_d)
+        expected_ret = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
 
-                    assert y.shape == expected_ret.shape
-                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
+        with mx.autograd.record():
+            y = test_concat(a, b, c, d)
 
-                    y.backward()
-                    if grad_req != 'null':
-                        assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
-                    if grad_req != 'null':
-                        assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
-                    if grad_req_c != 'null':
-                        assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
-                    if grad_req_d != 'null':
-                        assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+        assert y.shape == expected_ret.shape
+        assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3, atol=1e-5)
 
-                    # test imperative
-                    mx_out = np.concatenate([a, b, c, d], axis=axis)
-                    np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
-                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+        y.backward()
+        if grad_req != 'null':
+            assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+        if grad_req != 'null':
+            assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
+        if grad_req_c != 'null':
+            assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
+        if grad_req_d != 'null':
+            assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)
+
+        # test imperative
+        mx_out = np.concatenate([a, b, c, d], axis=axis)
+        np_out = _np.concatenate([a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()], axis=axis)
+        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
 @with_seed()
@@ -3704,6 +3717,16 @@ def __init__(self, a_min=None, a_max=None):
 
         def hybrid_forward(self, F, x):
             return x.clip(self._a_min, self._a_max)
+
+    # Test scalar case
+    for _, a_min, a_max, throw_exception in workloads:
+        a = _np.random.uniform()  # A scalar
+        if throw_exception:
+            # No need to test the exception case here.
+            continue
+        mx_ret = np.clip(a, a_min, a_max)
+        np_ret = _np.clip(a, a_min, a_max)
+        assert_almost_equal(mx_ret, np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False)
 
     for shape, a_min, a_max, throw_exception in workloads:
         for dtype in dtypes:
@@ -6605,7 +6628,7 @@ def hybrid_forward(self, F, a):
         ((5, 3, 4), True, True, True, 1),
     ]
     for dtype in ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64']:
-        for hybridize in [False]:
+        for hybridize in [False, True]:
             for config in configs:
                 test_unique = TestUnique(*config[1:])
                 if hybridize:
@@ -6998,6 +7021,26 @@ def dbg(name, data):
             # broadcast bug
             ('ij, ij -> i', [(1, 4), (2, 4)], lambda *args: (_np.sum(args[1], axis=0)[None, :],
                                                              _np.tile(args[0], [2, 1]))),
+            # one dimension bug
+            ('...ij, ...jk -> ...ik', [(1, 4), (4, 2)], lambda *args: (args[1].sum(axis=1)[None, :],
+                                                                       _np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))),
+            ('...ij, ...jk -> ...ik', [(2, 4), (4, 2)], lambda *args: (_np.tile(args[1].sum(axis=1)[None, :], [2, 1]),
+                                                                       _np.tile(args[0].sum(axis=0)[: ,None], [1, 2]))),
+            ('...ij, ...jk -> ...ik', [(3, 2, 1, 4), (3, 2, 4, 2)], lambda *args: (
+                args[1].sum(axis=3)[:, :, None, :],
+                _np.tile(args[0].sum(axis=2)[:, :, :, None], [1, 1, 1, 2]))),
+            ('...ij, ...ik -> ...jk', [(1, 1, 1, 4), (1, 1, 1, 3)], lambda *args: (
+                _np.tile(args[1].sum(axis=3)[:, :, :, None], [1, 1, 1, 4]),
+                _np.tile(args[0].sum(axis=3)[:, :, : ,None], [1, 1, 1, 3]))),
+            ('...ij, ...jc -> ...ic', [(1, 1, 5, 3), (1, 1, 3, 2)], lambda *args: (
+                _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
+                _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
+            ('...ij, ...jc -> ...ic', [(1, 2, 5, 4), (1, 2, 4, 2)], lambda *args: (
+                _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
+                _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
+            ('...ij, ...jc -> ...ic', [(2, 1, 5, 4), (2, 1, 4, 2)], lambda *args: (
+                _np.tile(args[1].sum(axis=3)[:, :, None, :], [1, 1, 5, 1]),
+                _np.tile(args[0].sum(axis=2)[:, :, : ,None], [1, 1, 1, 2]))),
             # issue #16576
             # commented due to long running time
             # ('abiz,abjz->abij', [(64, 8, 128, 512), (64, 8, 128, 512)], lambda *args: (_np.matmul(_np.ones((64, 8, 128, 128)), args[1]),
diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py
index a54bcf3e16a9..2aabfbf9b24b 100644
--- a/tests/python/unittest/test_symbol.py
+++ b/tests/python/unittest/test_symbol.py
@@ -440,9 +440,7 @@ def check_cse_on_symbol(sym, expected_savings, check_data, **kwargs):
     arr2 = mx.random.uniform(shape=shape)
     arr3 = mx.random.uniform(shape=shape)
 
-    check_cse_on_symbol((a+5) + (a+5), expected_savings=1, check_data=True, a=arr1, b=arr2)
     check_cse_on_symbol((a+1) + (a+2), expected_savings=0, check_data=True, a=arr1, b=arr2)
-    check_cse_on_symbol((1+a) + (a+1), expected_savings=1, check_data=True, a=arr1, b=arr2)
     check_cse_on_symbol((a+b) + (a+b), expected_savings=1, check_data=True, a=arr1, b=arr2)
     check_cse_on_symbol(((a+b)+c) +((a+b)+c), expected_savings=2, check_data=True, a=arr1, b=arr2, c=arr3)
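Note on the operator changes above: the `_scalar` operators now parse a NumpyBinaryScalarParam (the scalar value plus an `is_int` flag) instead of a bare double, and the kernels route integer or bool inputs through a temporary floating-point TBlob when the scalar is not integral. The following is a minimal illustrative sketch, not part of the patch, of the user-visible NumPy-style promotion this is meant to enable; it assumes the `mx.np` frontend built from this branch, and the exact promoted floating dtype is an assumption rather than something the diff asserts.

# Illustrative sketch only: expected promotion for binary scalar ops in NumPy mode.
import mxnet as mx
from mxnet import np, npx

npx.set_np()  # enable NumPy semantics

a = np.array([1, 2, 3], dtype='int32')

out_int = a + 2      # integral array + integral scalar: result stays integral
out_float = a + 2.5  # integral array + float scalar: result is promoted to a floating dtype

print(out_int.dtype)    # an integer dtype (e.g. int32)
print(out_float.dtype)  # a floating dtype (exact width depends on the build/defaults)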