From 47ea6849efb13b186d4c8d6138e9de78db58fddc Mon Sep 17 00:00:00 2001 From: Ziyi Mu Date: Fri, 29 May 2020 14:51:17 -0700 Subject: [PATCH] Revert PR 17767 for fixing GPU memory usage regression (#18283) (#18309) * Revert "Fix and optimize handling of vectorized memory accesses (#17767)" This reverts commit 5542d03695b4a2589afb88acf128d4ba8ac94d0d. * add license to reverted file --- 3rdparty/mshadow/mshadow/base.h | 48 +++ 3rdparty/mshadow/mshadow/half2.h | 162 +++++++++ src/common/cuda_vectorization.cuh | 283 --------------- src/operator/mshadow_op.h | 67 ++++ src/operator/tensor/elemwise_binary_op.cuh | 322 ------------------ src/operator/tensor/elemwise_binary_op.h | 206 +++++------ .../tensor/elemwise_binary_op_basic.cu | 23 +- .../tensor/elemwise_binary_scalar_op.cuh | 207 ----------- .../tensor/elemwise_binary_scalar_op.h | 75 +--- .../tensor/elemwise_binary_scalar_op_basic.cu | 9 +- .../elemwise_binary_scalar_op_extended.cu | 15 +- src/operator/tensor/elemwise_sum.cu | 112 +----- src/operator/tensor/elemwise_sum.h | 12 + src/operator/tensor/elemwise_unary_op.cuh | 127 ------- src/operator/tensor/elemwise_unary_op.h | 56 ++- .../tensor/elemwise_unary_op_basic.cu | 1 - src/operator/tensor/elemwise_unary_op_pow.cu | 1 - src/operator/tensor/elemwise_unary_op_trig.cu | 1 - tests/python/unittest/test_operator.py | 81 +---- 19 files changed, 464 insertions(+), 1344 deletions(-) create mode 100755 3rdparty/mshadow/mshadow/half2.h delete mode 100644 src/common/cuda_vectorization.cuh delete mode 100644 src/operator/tensor/elemwise_binary_op.cuh delete mode 100644 src/operator/tensor/elemwise_binary_scalar_op.cuh delete mode 100644 src/operator/tensor/elemwise_unary_op.cuh diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h index 6469bbc34f37..9f538574f093 100755 --- a/3rdparty/mshadow/mshadow/base.h +++ b/3rdparty/mshadow/mshadow/base.h @@ -295,6 +295,7 @@ extern "C" { } #include "./half.h" +#include "./half2.h" #include "./bfloat.h" #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP) \ MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, mshadow::bfloat::bf16_t b) { \ @@ -409,6 +410,11 @@ struct DataType { #endif }; template<> +struct DataType { + static const int kFlag = kFloat16; + static const int kLanes = 2; +}; +template<> struct DataType { static const int kFlag = kBfloat16; static const int kLanes = 1; @@ -1161,6 +1167,48 @@ struct minimum { } #endif +#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...) \ + switch (type) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half2_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt32: \ + { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kInt64: \ + { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + #define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...) \ switch (type) { \ case mshadow::kFloat32: \ diff --git a/3rdparty/mshadow/mshadow/half2.h b/3rdparty/mshadow/mshadow/half2.h new file mode 100755 index 000000000000..cecc5449383c --- /dev/null +++ b/3rdparty/mshadow/mshadow/half2.h @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2017 by Contributors + * \file half2.h + * \brief definition of vector float16, half2 type. + * + * \author Antti-Pekka Hynninen + */ +#ifndef MSHADOW_HALF2_H_ +#define MSHADOW_HALF2_H_ + +#if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && CUDA_VERSION >= 7050) + #define MSHADOW_CUDA_HALF2 1 + #include +#else + #define MSHADOW_CUDA_HALF2 0 +#endif + +#include + +/*! \brief namespace for mshadow */ +namespace mshadow { +/* \brief name space for host/device portable half-precision floats */ +namespace half { + +#define MSHADOW_HALF2_ASSIGNOP(AOP, OP) \ + template \ + MSHADOW_XINLINE half2_t operator AOP (const T& a) { \ + return *this = half2_t(*this OP a); /* NOLINT(*)*/ \ + } \ + +class MSHADOW_ALIGNED(4) half2_t { + public: +#if MSHADOW_CUDA_HALF2 + half2 half2_; +#else + half_t half_t2[2]; +#endif + + MSHADOW_XINLINE half2_t() {} + +#if MSHADOW_CUDA_HALF2 + MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {} +#else + MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) { + half_t2[0] = a; + half_t2[1] = b; + } +#endif + + MSHADOW_XINLINE explicit half2_t(int a) { +#if MSHADOW_CUDA_HALF2 + half2_ = __half2half2(__int2half_rz(a)); +#else + half_t2[0] = (half_t)a; + half_t2[1] = (half_t)a; +#endif + } + + MSHADOW_XINLINE half2_t operator+() { + return *this; + } + + MSHADOW_XINLINE half2_t operator-() { +#if MSHADOW_CUDA_HALF2 + return half2_t(__hneg2(half2_)); +#else + return half2_t(-half_t2[0], -half_t2[1]); +#endif + } + + MSHADOW_XINLINE half2_t operator=(const half2_t& a) { +#if MSHADOW_CUDA_HALF2 + half2_ = a.half2_; +#else + half_t2[0] = a.half_t2[0]; + half_t2[1] = a.half_t2[1]; +#endif + return a; + } + + MSHADOW_HALF2_ASSIGNOP(+=, +) + MSHADOW_HALF2_ASSIGNOP(-=, -) + MSHADOW_HALF2_ASSIGNOP(*=, *) + MSHADOW_HALF2_ASSIGNOP(/=, /) +}; + +/*! \brief overloaded + operator for half2_t */ +MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) + __low2float(b.half2_), + __high2float(a.half2_) + __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]); +#endif +} +/*! \brief overloaded - operator for half2_t */ +MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) - __low2float(b.half2_), + __high2float(a.half2_) - __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]); +#endif +} +/*! \brief overloaded * operator for half2_t */ +MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) * __low2float(b.half2_), + __high2float(a.half2_) * __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]); +#endif +} +/*! \brief overloaded / operator for half2_t */ +MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(__low2float(a.half2_) / __low2float(b.half2_), + __high2float(a.half2_) / __high2float(b.half2_))); +#else + return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]); +#endif +} +/*! \brief overloaded % operator for half2_t */ +MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), __low2float(b.half2_)), + ::fmod(__high2float(a.half2_), __high2float(b.half2_)))); +#else + return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], b.half_t2[1])); +#endif +} +/*! \brief overloaded == operator for half2_t */ +MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) { +#if MSHADOW_CUDA_HALF2 + return __hbeq2(a.half2_, b.half2_); +#else + return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]); +#endif +} + +} // namespace half +} // namespace mshadow +#endif // MSHADOW_HALF2_H_ diff --git a/src/common/cuda_vectorization.cuh b/src/common/cuda_vectorization.cuh deleted file mode 100644 index 7803afb901ab..000000000000 --- a/src/common/cuda_vectorization.cuh +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2020 by Contributors - * \file cuda_vectorization.cuh - * \brief GPU helpers for vectorized memory accesses - */ - -#ifndef MXNET_COMMON_CUDA_VECTORIZATION_CUH_ -#define MXNET_COMMON_CUDA_VECTORIZATION_CUH_ - -#if MXNET_USE_CUDA && __CUDACC__ - -#include -#include "cuda_utils.h" - - -namespace mxnet { -namespace common { -namespace cuda { - -/* \brief Helper class that enables storing multiple values of type DType - as 1 value of type LType. -*/ -template -class VectorizedStorage { - public: - constexpr static int nvec = sizeof(LType) / sizeof(DType); - union vectorized_storage { - LType aligned; - DType separate[nvec]; // NOLINT(*) - - MSHADOW_XINLINE vectorized_storage() {} - MSHADOW_XINLINE ~vectorized_storage() {} - } scratch_; -}; - -/* \brief Helper class that enables accessing multiple values of type DType - as 1 value of type LType. Additional aligned template argument - allows performance optimizations if the pointer and the size of - the allocation is aligned to sizeof(LType) / sizeof(DType) elements. -*/ -template -class VectorizedAccessor { - public: - using StorageType = VectorizedStorage::type, - typename std::remove_const::type>; - StorageType storage_; - - LType* aligned_ptr_; - DType* unaligned_ptr_; - int alignment_; - index_t n_elems_; - - MSHADOW_XINLINE VectorizedAccessor(DType* ptr, const index_t size) { - unaligned_ptr_ = ptr; - if (aligned) { - alignment_ = 0; - aligned_ptr_ = reinterpret_cast(ptr); - n_elems_ = (size + storage_.nvec - 1) / storage_.nvec; - } else { - size_t ptr_as_number = reinterpret_cast(ptr); - alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType); - aligned_ptr_ = reinterpret_cast(ptr - alignment_); - n_elems_ = (size + alignment_ + storage_.nvec - 1) / storage_.nvec; - } - } - - /* \brief Alignment of the input pointer in elements. */ - MSHADOW_XINLINE int alignment() const { - return alignment_; - } - - /* \brief Access to separate elements. */ - MSHADOW_XINLINE DType* separate() { - return storage_.scratch_.separate; - } - - /* \brief Number of elements stored. */ - MSHADOW_XINLINE constexpr int nvec() const { - return storage_.nvec; - } - - /* \brief Number of aligned elements that span the entire input tensor. */ - MSHADOW_XINLINE index_t num_aligned_elements() const { - return n_elems_; - } - - /* \brief Load values from the input. - \param id Aligned index of the element. - \param N size of the tensor. - */ - MSHADOW_XINLINE void load(const index_t id, const index_t N) { - if (aligned) { - storage_.scratch_.aligned = aligned_ptr_[id]; - } else { - if (id > 0 && id < n_elems_ - 1) { - storage_.scratch_.aligned = aligned_ptr_[id]; - } else { -#pragma unroll - for (int j = 0; j < storage_.nvec; ++j) { - DType* ptr = reinterpret_cast(&(aligned_ptr_[id])) + j; - if (reinterpret_cast(ptr) >= reinterpret_cast(unaligned_ptr_) && - reinterpret_cast(ptr) < reinterpret_cast(unaligned_ptr_ + N)) { - storage_.scratch_.separate[j] = *ptr; - } - } - } - } - } -}; - -/* \brief Class used for vectorized read-only access. */ -template -class VectorizedLoader : public VectorizedAccessor { - public: - MSHADOW_XINLINE VectorizedLoader(const DType* ptr, const index_t N) : - VectorizedAccessor(ptr, N) { - } -}; - -/* \brief Class used for vectorized writable access. */ -template -class VectorizedStorer : public VectorizedAccessor { - public: - MSHADOW_XINLINE VectorizedStorer(DType* ptr, const index_t N) : - VectorizedAccessor(ptr, N) { - } - - /* \brief Store values to the output. - \param id Aligned index of the element. - \param N size of the tensor. - */ - MSHADOW_XINLINE void store(const index_t id, const index_t N) { - if (aligned) { - this->aligned_ptr_[id] = this->storage_.scratch_.aligned; - } else { - if (id > 0 && id < this->n_elems_ - 1) { - this->aligned_ptr_[id] = this->storage_.scratch_.aligned; - } else { -#pragma unroll - for (int j = 0; j < this->storage_.nvec; ++j) { - DType* ptr = reinterpret_cast(&(this->aligned_ptr_[id])) + j; - if (reinterpret_cast(ptr) >= reinterpret_cast(this->unaligned_ptr_) && - reinterpret_cast(ptr) < reinterpret_cast(this->unaligned_ptr_ + N)) { - *ptr = this->storage_.scratch_.separate[j]; - } - } - } - } - } -}; - -namespace { - -enum class Alignment { - SAME_ALIGNED, // All tensors aligned - SAME_UNALIGNED, // All tensors have the same misalignment - DIFFERENT // Tensors have different alignment -}; - -template -int CalcAlignment(const DType* ptr) { - size_t ptr_as_number = reinterpret_cast(ptr); - return ptr_as_number % sizeof(LType); -} - -/* \brief Check alignment of the inputs and outputs when cast to LType*. - \param params Structuce containing arrays with inputs' and outputs' pointers - \param lead_dim Leading dimension of the tensors. - \param other_dim The size of the other dimensions of the tensors. -*/ -template -Alignment CheckAlignment(const Params& params, const index_t lead_dim, const index_t other_dim) { - int align = -1; - constexpr int nvec = sizeof(LType) / sizeof(DType); - - for (const DType* ptr : params.inputs) { - int new_align = CalcAlignment(ptr); - if (align == -1) { - align = new_align; - } else { - if (align != new_align) { - return Alignment::DIFFERENT; - } - } - } - - for (const DType* ptr : params.outputs) { - int new_align = CalcAlignment(ptr); - if (align == -1) { - align = new_align; - } else { - if (align != new_align) { - return Alignment::DIFFERENT; - } - } - } - - if ((other_dim != 1) && - (lead_dim % nvec != 0)) { - return Alignment::DIFFERENT; - } - - if ((align == 0) && - (lead_dim % nvec == 0)) { - return Alignment::SAME_ALIGNED; - } else { - return Alignment::SAME_UNALIGNED; - } -} - -constexpr int vectorized_kernel_thread_num = 512; - -} // namespace - -/* \brief Helper launcher function for the vectorized kernels. Checks for alignment of the - input and output tensors and launches a proper template. - \param lead_dim Leading dimension of the tensors. - \param other_dim The size of the other dimensions. - \param s Stream which should be used for launching the kernel. - \param params Input parameters to the kernel. Needs to contain at least 2 arrays of DType*: - inputs and outputs, which contain input and output pointers. -*/ -template -void VectorizedKernelLauncher(const index_t lead_dim, - const index_t other_dim, - mshadow::Stream* s, - typename Kernel::ParamType params) { - static_assert(sizeof(LType) >= sizeof(DType), "Load type is smaller than operand type"); - if (lead_dim * other_dim != 0) { - cudaStream_t stream = mshadow::Stream::GetStream(s); - VectorizedLoader l(params.inputs[0], lead_dim); - size_t num_elements = other_dim * l.num_aligned_elements(); - constexpr int threads = vectorized_kernel_thread_num; - constexpr int max_blocks = 65535; - index_t blocks = std::min(static_cast((num_elements + threads - 1) / threads), - max_blocks); - auto align = CheckAlignment(params, lead_dim, other_dim); - switch (align) { - case Alignment::SAME_ALIGNED: - Kernel::template Launch(blocks, threads, stream, params, lead_dim, other_dim); - break; - case Alignment::SAME_UNALIGNED: - Kernel::template Launch(blocks, threads, stream, params, lead_dim, other_dim); - break; - case Alignment::DIFFERENT: { - const index_t size = lead_dim * other_dim; - index_t blocks = std::min(static_cast((size + threads - 1) / - threads), - max_blocks); - // If the pointers are aligned differently we cannot vectorize - Kernel::template Launch(blocks, threads, stream, params, lead_dim, other_dim); - break; - } - } - } -} - -} // namespace cuda -} // namespace common -} // namespace mxnet - -#endif // MXNET_USE_CUDA && __CUDACC__ - -#endif // MXNET_COMMON_CUDA_VECTORIZATION_CUH_ diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index e0bbb4e5c935..2d4d49254676 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -730,8 +730,22 @@ MXNET_BINARY_MATH_OP(rminus, b - a); MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b)); +template<> +MSHADOW_XINLINE mshadow::half::half2_t div_grad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + return mshadow::half::half2_t(1) / b; +} + MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b)); +template<> +MSHADOW_XINLINE mshadow::half::half2_t div_rgrad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + return -a / (b * b); +} + MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a)); MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a)); @@ -782,6 +796,13 @@ struct mod : public mxnet_op::tunable { }; +template<> +MSHADOW_XINLINE mshadow::half::half2_t mod::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + return a%b; +} + struct mod_grad : public mxnet_op::tunable { template MSHADOW_XINLINE static DType Map(DType a, DType b) { @@ -803,6 +824,19 @@ MSHADOW_XINLINE mshadow::half::half_t mod_grad::Map mshadow::half::half_t b) { return mshadow::half::half_t(1.0f); } +template<> +MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + mshadow::half::half2_t result = mshadow::half::half2_t(); +#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2) + result.half2_ = ::__float2half2_rn(1.0f); +#else + result.half_t2[0] = mshadow::half::half_t(0.0f); + result.half_t2[1] = mshadow::half::half_t(1.0f); +#endif + return result; +} struct mod_rgrad : public mxnet_op::tunable { template @@ -825,6 +859,19 @@ MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map mshadow::half::half_t b) { return mshadow::half::half_t(-::floorf(static_cast(a/b))); } +template<> +MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { +#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2) + return mshadow::half::half2_t(__hneg2(::h2floor((a/b).half2_))); +#else + return mshadow::half::half2_t(mshadow::half::half_t(-::floorf( + static_cast(a.half_t2[0]/b.half_t2[0]))), + mshadow::half::half_t(-::floorf( + static_cast(a.half_t2[1]/b.half_t2[1])))); +#endif +} struct rmod : public mxnet_op::tunable { template @@ -861,6 +908,13 @@ struct rmod : public mxnet_op::tunable { } }; +template<> +MSHADOW_XINLINE mshadow::half::half2_t rmod::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + return b%a; +} + struct rmod_grad { template MSHADOW_XINLINE static DType Map(DType a, DType b) { @@ -882,6 +936,19 @@ MSHADOW_XINLINE mshadow::half::half_t rmod_grad::Map mshadow::half::half_t b) { return mshadow::half::half_t(-::floorf(static_cast(b/a))); } +template<> +MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { +#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2) + return mshadow::half::half2_t(::__hneg2(::h2floor((b/a).half2_))); +#else + return mshadow::half::half2_t(mshadow::half::half_t(-::floorf( + static_cast(b.half_t2[0]/a.half_t2[0]))), + mshadow::half::half_t(-::floorf( + static_cast(b.half_t2[1]/a.half_t2[1])))); +#endif +} struct clip : public mxnet_op::tunable { template diff --git a/src/operator/tensor/elemwise_binary_op.cuh b/src/operator/tensor/elemwise_binary_op.cuh deleted file mode 100644 index 0bb9fa636f45..000000000000 --- a/src/operator/tensor/elemwise_binary_op.cuh +++ /dev/null @@ -1,322 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2020 by Contributors - * \file elemwise_binary_op.cuh - * \brief GPU helpers for elementwise operators - */ - -#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_CUH_ -#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_CUH_ - -#include -#include "../operator_common.h" -#include "../../common/cuda_vectorization.cuh" - -#include - -#if MXNET_USE_CUDA - -namespace mxnet { -namespace op { - -namespace binary { - -using common::cuda::VectorizedKernelLauncher; -using common::cuda::VectorizedLoader; -using common::cuda::VectorizedStorer; - -template -struct VectorizedBinaryKernelParams { - const DType* inputs[NumInputs]; - DType* outputs[NumOutputs]; -}; - -template -__global__ void VectorizedBinaryKernelFwd(const VectorizedBinaryKernelParams params, - const index_t N) { - VectorizedLoader loader0(params.inputs[0], N); - VectorizedLoader loader1(params.inputs[1], N); - VectorizedStorer storer(params.outputs[0], N); - - const index_t M = loader0.num_aligned_elements(); - - for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { - loader0.load(tid, N); - loader1.load(tid, N); - if (req == kAddTo) { - storer.load(tid, N); - } -#pragma unroll - for (int i = 0; i < loader0.nvec(); ++i) { - DType temp = OP::Map(loader0.separate()[i], - loader1.separate()[i]); - - if (req == kAddTo) { - storer.separate()[i] += temp; - } else { - storer.separate()[i] = temp; - } - } - storer.store(tid, N); - } -} - -template -__global__ void VectorizedBinaryKernelBwdUseNone( - const VectorizedBinaryKernelParams params, - const index_t N) { - VectorizedLoader loader(params.inputs[0], N); - VectorizedStorer lstorer(params.outputs[0], N); - VectorizedStorer rstorer(params.outputs[1], N); - - const index_t M = loader.num_aligned_elements(); - - for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { - loader.load(tid, N); - if (lreq == kAddTo) { - lstorer.load(tid, N); - } - if (rreq == kAddTo) { - rstorer.load(tid, N); - } -#pragma unroll - for (int i = 0; i < loader.nvec(); ++i) { - DType inp = loader.separate()[i]; - if (!((std::is_same::value && lreq == kWriteInplace) || - lreq == kNullOp)) { - DType ltemp = LOP::Map(inp); - if (lreq == kAddTo) { - lstorer.separate()[i] += ltemp; - } else { - lstorer.separate()[i] = ltemp; - } - lstorer.store(tid, N); - } - if (!((std::is_same::value && rreq == kWriteInplace) || - rreq == kNullOp)) { - DType rtemp = ROP::Map(inp); - - if (rreq == kAddTo) { - rstorer.separate()[i] += rtemp; - } else { - rstorer.separate()[i] = rtemp; - } - rstorer.store(tid, N); - } - } - } -} - -template -__global__ void VectorizedBinaryKernelBwdUseIn( - const VectorizedBinaryKernelParams params, - const index_t N) { - VectorizedLoader ograd_loader(params.inputs[0], N); - VectorizedLoader linput_loader(params.inputs[1], N); - VectorizedLoader rinput_loader(params.inputs[2], N); - VectorizedStorer lstorer(params.outputs[0], N); - VectorizedStorer rstorer(params.outputs[1], N); - - const index_t M = ograd_loader.num_aligned_elements(); - - for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { - ograd_loader.load(tid, N); - linput_loader.load(tid, N); - rinput_loader.load(tid, N); - if (lreq == kAddTo) { - lstorer.load(tid, N); - } - if (rreq == kAddTo) { - rstorer.load(tid, N); - } -#pragma unroll - for (int i = 0; i < ograd_loader.nvec(); ++i) { - DType ograd = ograd_loader.separate()[i]; - DType linput = linput_loader.separate()[i]; - DType rinput = rinput_loader.separate()[i]; - if (!(lreq == kNullOp)) { - DType ltemp = ograd * LOP::Map(linput, rinput); - if (lreq == kAddTo) { - lstorer.separate()[i] += ltemp; - } else { - lstorer.separate()[i] = ltemp; - } - lstorer.store(tid, N); - } - if (!(rreq == kNullOp)) { - DType rtemp = ograd * ROP::Map(linput, rinput); - - if (rreq == kAddTo) { - rstorer.separate()[i] += rtemp; - } else { - rstorer.separate()[i] = rtemp; - } - rstorer.store(tid, N); - } - } - } -} - -template -class VectorizedBinaryFwd { - public: - using ParamType = VectorizedBinaryKernelParams; - - template - static void Launch(const index_t blocks, const index_t threads, - cudaStream_t stream, - const ParamType params, const index_t lead_dim, - const index_t /* other_dim */) { - VectorizedBinaryKernelFwd - <<>>(params, lead_dim); - } -}; - -template -class VectorizedBinaryBwdUseNone { - public: - using ParamType = VectorizedBinaryKernelParams; - - template - static void Launch(const index_t blocks, const index_t threads, - cudaStream_t stream, - const ParamType params, const index_t lead_dim, - const index_t /* other_dim */) { - VectorizedBinaryKernelBwdUseNone - <<>>(params, lead_dim); - } -}; - -template -class VectorizedBinaryBwdUseIn { - public: - using ParamType = VectorizedBinaryKernelParams; - - template - static void Launch(const index_t blocks, const index_t threads, - cudaStream_t stream, - const ParamType params, const index_t lead_dim, - const index_t /* other_dim */) { - VectorizedBinaryKernelBwdUseIn - <<>>(params, lead_dim); - } -}; - -} // namespace binary - -template -void ElemwiseBinaryOp::Compute_(const nnvm::NodeAttrs &attrs, - mshadow::Stream *s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace binary; - if (req[0] == kNullOp) return; - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - using LType = uint4; - using Kernel = VectorizedBinaryFwd; - - const index_t size = outputs[0].Size(); - typename Kernel::ParamType params; - params.inputs[0] = inputs[0].dptr(); - params.inputs[1] = inputs[1].dptr(); - params.outputs[0] = outputs[0].dptr(); - - VectorizedKernelLauncher(size, 1, s, params); - }); - }); -} - -template -void ElemwiseBinaryOp::BackwardUseNone_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace binary; - cudaStream_t stream = mshadow::Stream::GetStream(s); - - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - const index_t size = inputs[0].Size(); - if (req[0] != kNullOp || req[1] != kNullOp) { - MXNET_REQ_TYPE_SWITCH(req[0], lreq, { - MXNET_REQ_TYPE_SWITCH(req[1], rreq, { - using LType = uint4; - using Kernel = VectorizedBinaryBwdUseNone; - - typename Kernel::ParamType params; - params.inputs[0] = inputs[0].dptr(); - params.outputs[0] = outputs[0].dptr(); - params.outputs[1] = outputs[1].dptr(); - - VectorizedKernelLauncher(size, 1, s, params); - }); - }); - } - }); -} - -template -void ElemwiseBinaryOp::BackwardUseIn_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace binary; - if (req[0] != kNullOp || req[1] != kNullOp) { - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - MXNET_REQ_TYPE_SWITCH(req[0], lreq, { - MXNET_REQ_TYPE_SWITCH(req[1], rreq, { - const index_t size = inputs[0].Size(); - // Using 64 bit loads to reduce register pressure - using LType = uint2; - using Kernel = VectorizedBinaryBwdUseIn; - - typename Kernel::ParamType params; - params.inputs[0] = inputs[0].dptr(); - params.inputs[1] = inputs[1].dptr(); - params.inputs[2] = inputs[2].dptr(); - params.outputs[0] = outputs[0].dptr(); - params.outputs[1] = outputs[1].dptr(); - - VectorizedKernelLauncher(size, 1, s, params); - }); - }); - }); - } -} - -} // namespace op -} // namespace mxnet - -#endif // MXNET_USE_CUDA -#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_CUH_ diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index b9396aee204e..bc5140a5d75f 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -106,85 +106,62 @@ class ElemwiseBinaryOp : public OpBase { } private: - template + template static void BackwardUseNone_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, + const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - using namespace mxnet_op; - const int size = static_cast((outputs[0].Size() + DataType::kLanes - 1) - / DataType::kLanes); - const DType *ograd_dptr = inputs[0].dptr(); - if (std::is_same::value && req[0] == kWriteInplace) { - CHECK_EQ(ograd_dptr, outputs[0].dptr()); - } else if (req[0] != kNullOp) { - DType *lgrad_dptr = outputs[0].dptr(); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - Kernel, cpu>::Launch(s, size, lgrad_dptr, ograd_dptr); - }); - } - if (std::is_same::value && req[1] == kWriteInplace) { - CHECK_EQ(ograd_dptr, outputs[1].dptr()); - } else if (req[1] != kNullOp) { - DType *rgrad_dptr = outputs[1].dptr(); - MXNET_ASSIGN_REQ_SWITCH(req[1], Req, { - Kernel, cpu>::Launch(s, size, rgrad_dptr, ograd_dptr); - }); - } - }); - } -#if MXNET_USE_CUDA - template - static void BackwardUseNone_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); -#endif - - template - static void BackwardUseIn_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - DCHECK_EQ(outputs.size(), 2U); - DCHECK_EQ(inputs.size(), 3U); - const DType *ograd_dptr = inputs[0].dptr(); - const DType *lhs_dptr = inputs[1].dptr(); - const DType *rhs_dptr = inputs[2].dptr(); + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + const int size = static_cast((outputs[0].Size() + DataType::kLanes - 1) + / DataType::kLanes); + const DType *ograd_dptr = inputs[0].dptr(); + if (std::is_same::value && req[0] == kWriteInplace) { + CHECK_EQ(ograd_dptr, outputs[0].dptr()); + } else if (req[0] != kNullOp) { + DType *lgrad_dptr = outputs[0].dptr(); MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - const int size = static_cast( - (outputs[0].Size() + mxnet_op::DataType::kLanes - 1) - / mxnet_op::DataType::kLanes); - DType * lgrad_dptr = outputs[0].dptr(); - mxnet_op::Kernel< - mxnet_op::op_with_req, Req>, cpu>::Launch( - s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr); + Kernel, xpu>::Launch(s, size, lgrad_dptr, ograd_dptr); }); + } + if (std::is_same::value && req[1] == kWriteInplace) { + CHECK_EQ(ograd_dptr, outputs[1].dptr()); + } else if (req[1] != kNullOp) { + DType *rgrad_dptr = outputs[1].dptr(); MXNET_ASSIGN_REQ_SWITCH(req[1], Req, { - const int size = static_cast( - (outputs[1].Size() + mxnet_op::DataType::kLanes - 1) - / mxnet_op::DataType::kLanes); - DType * rgrad_dptr = outputs[1].dptr(); - mxnet_op::Kernel< - mxnet_op::op_with_req, Req>, cpu>::Launch( - s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr); + Kernel, xpu>::Launch(s, size, rgrad_dptr, ograd_dptr); }); - }); + } } -#if MXNET_USE_CUDA - template + template static void BackwardUseIn_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, + const OpContext &ctx, const std::vector &inputs, const std::vector &req, - const std::vector &outputs); -#endif + const std::vector &outputs) { + DCHECK_EQ(outputs.size(), 2U); + DCHECK_EQ(inputs.size(), 3U); + mxnet_op::Stream *s = ctx.get_stream(); + const DType *ograd_dptr = inputs[0].dptr(); + const DType *lhs_dptr = inputs[1].dptr(); + const DType *rhs_dptr = inputs[2].dptr(); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + const int size = static_cast( + (outputs[0].Size() + mxnet_op::DataType::kLanes - 1) + / mxnet_op::DataType::kLanes); + DType * lgrad_dptr = outputs[0].dptr(); + mxnet_op::Kernel, Req>, xpu>::Launch( + s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);}); + MXNET_ASSIGN_REQ_SWITCH(req[1], Req, { + const int size = static_cast( + (outputs[1].Size() + mxnet_op::DataType::kLanes - 1) + / mxnet_op::DataType::kLanes); + DType * rgrad_dptr = outputs[1].dptr(); + mxnet_op::Kernel, Req>, xpu>::Launch( + s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);}); + } template< typename xpu, @@ -521,13 +498,15 @@ class ElemwiseBinaryOp : public OpBase { }); } - template - static void Compute_(const nnvm::NodeAttrs &attrs, - mshadow::Stream *s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + template + static void Compute(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mxnet_op; + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); if (outputs[0].type_flag_ == mshadow::kBool) { @@ -538,7 +517,7 @@ class ElemwiseBinaryOp : public OpBase { const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + DataType::kLanes - 1) / DataType::kLanes; if (size != 0) { - Kernel, cpu>::Launch(s, size, + Kernel, xpu>::Launch(s, size, outputs[0].dptr(), inputs[0].dptr(), inputs[1].dptr()); } @@ -546,26 +525,6 @@ class ElemwiseBinaryOp : public OpBase { }); } -#if MXNET_USE_CUDA - template - static void Compute_(const nnvm::NodeAttrs &attrs, - mshadow::Stream *s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); -#endif - - template - static void Compute(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - if (req[0] == kNullOp) return; - mshadow::Stream *s = ctx.get_stream(); - Compute_(attrs, s, inputs, req, outputs); - } - template static void ComputeWithBool(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -615,6 +574,30 @@ class ElemwiseBinaryOp : public OpBase { }); } + template + static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } + }); + }); + } + template static void ComputeEx(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -711,8 +694,20 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - mshadow::Stream *s = ctx.get_stream(); - BackwardUseNone_(attrs, s, inputs, req, outputs); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BackwardUseNone_(attrs, ctx, inputs, req, outputs); + }); + } + + template + static inline void BackwardUseNoneWithHalf2(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { + BackwardUseNone_(attrs, ctx, inputs, req, outputs); + }); } template @@ -756,8 +751,20 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - mshadow::Stream *s = ctx.get_stream(); - BackwardUseIn_(attrs, s, inputs, req, outputs); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BackwardUseIn_(attrs, ctx, inputs, req, outputs); + }); + } + + template + static inline void BackwardUseInWithHalf2(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { + BackwardUseIn_(attrs, ctx, inputs, req, outputs); + }); } template< @@ -856,9 +863,4 @@ class ElemwiseBinaryOp : public OpBase { } // namespace op } // namespace mxnet - -#ifdef __CUDACC__ -#include "elemwise_binary_op.cuh" -#endif // __CUDACC__ - #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_ diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index b21b08d03217..16d7fc1ad72b 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -218,51 +218,52 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream *s, } NNVM_REGISTER_OP(elemwise_add) -.set_attr("FCompute", ElemwiseBinaryOp::Compute) +.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2) .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeEx); NNVM_REGISTER_OP(_grad_add) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2); NNVM_REGISTER_OP(_backward_add) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseNone); NNVM_REGISTER_OP(elemwise_sub) -.set_attr("FCompute", ElemwiseBinaryOp::Compute) +.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2< + gpu, op::mshadow_op::minus>) .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeEx); NNVM_REGISTER_OP(_backward_sub) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseNone); NNVM_REGISTER_OP(elemwise_mul) -.set_attr("FCompute", ElemwiseBinaryOp::Compute) +.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2) .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeDnsLRValueEx); NNVM_REGISTER_OP(_backward_mul) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseIn); NNVM_REGISTER_OP(elemwise_div) .set_attr("FCompute", - ElemwiseBinaryOp::Compute); + ElemwiseBinaryOp::ElemwiseBinaryOp::ComputeWithHalf2); NNVM_REGISTER_OP(_backward_div) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseIn); NNVM_REGISTER_OP(_mod) -.set_attr("FCompute", ElemwiseBinaryOp::Compute); +.set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2); NNVM_REGISTER_OP(_backward_mod) .set_attr("FCompute", - ElemwiseBinaryOp::BackwardUseIn); + ElemwiseBinaryOp::BackwardUseInWithHalf2); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cuh b/src/operator/tensor/elemwise_binary_scalar_op.cuh deleted file mode 100644 index 062c18767ac6..000000000000 --- a/src/operator/tensor/elemwise_binary_scalar_op.cuh +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2020 by Contributors - * \file elemwise_binary_scalar_op.cuh - * \brief GPU helpers for binary elementwise operators with scalar - */ - -#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_CUH_ -#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_CUH_ - -#include -#include "../operator_common.h" -#include "../../common/cuda_vectorization.cuh" - -#include - -#if MXNET_USE_CUDA - -namespace mxnet { -namespace op { - -namespace binary_scalar { - -using common::cuda::VectorizedKernelLauncher; -using common::cuda::VectorizedLoader; -using common::cuda::VectorizedStorer; - -template -struct VectorizedKernelParams { - const DType* inputs[NumInputs]; - DType* outputs[NumOutputs]; - DType scalar; -}; - -template -__global__ void VectorizedBinaryScalarKernelFwd(const VectorizedKernelParams params, - const index_t N) { - VectorizedLoader loader0(params.inputs[0], N); - VectorizedStorer storer(params.outputs[0], N); - - const index_t M = loader0.num_aligned_elements(); - - for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { - loader0.load(tid, N); - if (req == kAddTo) { - storer.load(tid, N); - } -#pragma unroll - for (int i = 0; i < loader0.nvec(); ++i) { - DType temp = OP::Map(loader0.separate()[i], - params.scalar); - - if (req == kAddTo) { - storer.separate()[i] += temp; - } else { - storer.separate()[i] = temp; - } - } - storer.store(tid, N); - } -} - -template -__global__ void VectorizedBinaryScalarKernelBwd(const VectorizedKernelParams params, - const index_t N) { - VectorizedLoader ograd_loader(params.inputs[0], N); - VectorizedLoader input_loader(params.inputs[1], N); - VectorizedStorer storer(params.outputs[0], N); - - const index_t M = ograd_loader.num_aligned_elements(); - - for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { - ograd_loader.load(tid, N); - input_loader.load(tid, N); - if (req == kAddTo) { - storer.load(tid, N); - } -#pragma unroll - for (int i = 0; i < ograd_loader.nvec(); ++i) { - DType ograd = ograd_loader.separate()[i]; - DType temp = ograd * OP::Map(input_loader.separate()[i], - params.scalar); - - if (req == kAddTo) { - storer.separate()[i] += temp; - } else { - storer.separate()[i] = temp; - } - } - storer.store(tid, N); - } -} - -template -class VectorizedBinaryScalarFwd { - public: - using ParamType = VectorizedKernelParams; - - template - static void Launch(const index_t blocks, const index_t threads, - cudaStream_t stream, - const ParamType params, const index_t lead_dim, - const index_t /* other_dim */) { - VectorizedBinaryScalarKernelFwd - <<>>(params, lead_dim); - } -}; - -template -class VectorizedBinaryScalarBwd { - public: - using ParamType = VectorizedKernelParams; - - template - static void Launch(const index_t blocks, const index_t threads, - cudaStream_t stream, - const ParamType params, const index_t lead_dim, - const index_t /* other_dim */) { - VectorizedBinaryScalarKernelBwd - <<>>(params, lead_dim); - } -}; - -} // namespace binary_scalar - -template -void BinaryScalarOp::Compute_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace binary_scalar; - if (req[0] == kNullOp) return; - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - const double alpha = nnvm::get(attrs.parsed); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - using LType = uint4; - using Kernel = VectorizedBinaryScalarFwd; - - const index_t size = outputs[0].Size(); - typename Kernel::ParamType params; - params.inputs[0] = inputs[0].dptr(); - params.outputs[0] = outputs[0].dptr(); - params.scalar = (DType)alpha; - - VectorizedKernelLauncher(size, 1, s, params); - }); - }); -} - -template -void BinaryScalarOp::Backward_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace binary_scalar; - if (req[0] == kNullOp) return; - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - const double alpha = nnvm::get(attrs.parsed); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - using LType = uint4; - using Kernel = VectorizedBinaryScalarBwd; - - const index_t size = outputs[0].Size(); - typename Kernel::ParamType params; - params.inputs[0] = inputs[0].dptr(); - params.inputs[1] = inputs[1].dptr(); - params.outputs[0] = outputs[0].dptr(); - params.scalar = (DType)alpha; - - VectorizedKernelLauncher(size, 1, s, params); - }); - }); -} - -} // namespace op -} // namespace mxnet - -#endif // MXNET_USE_CUDA -#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_CUH_ diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index 53161ee2354f..4eaaff09d83d 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -225,44 +225,26 @@ class BinaryScalarOp : public UnaryOp { } public: - template - static void Compute_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + template + static void Compute(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { DCHECK_EQ(inputs.size(), 1); DCHECK_EQ(outputs.size(), 1); using namespace mshadow; using namespace mshadow::expr; + Stream *s = ctx.get_stream(); const double alpha = nnvm::get(attrs.parsed); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - mxnet_op::Kernel, cpu>::Launch( + mxnet_op::Kernel, xpu>::Launch( s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); }); }); } -#if MXNET_USE_CUDA - template - static void Compute_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); -#endif - - template - static void Compute(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - mshadow::Stream *s = ctx.get_stream(); - Compute_(attrs, s, inputs, req, outputs); - } - template static void ComputeInt(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -354,46 +336,26 @@ class BinaryScalarOp : public UnaryOp { } } - template - static void Backward_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + template + static void Backward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; using namespace mshadow::expr; + Stream *s = ctx.get_stream(); const double alpha = nnvm::get(attrs.parsed); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { mxnet::op::mxnet_op::Kernel, Req>, cpu>:: + mxnet::op::mxnet_op::backward_grad_tuned, Req>, xpu>:: Launch(s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), inputs[1].dptr(), DType(alpha)); }); }); } - -#if MXNET_USE_CUDA - template - static void Backward_(const nnvm::NodeAttrs &attrs, - mshadow::Stream* s, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); -#endif - - template - static void Backward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Backward_(attrs, s, inputs, req, outputs); - } }; #define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name) \ @@ -414,9 +376,4 @@ class BinaryScalarOp : public UnaryOp { } // namespace op } // namespace mxnet - -#ifdef __CUDACC__ -#include "elemwise_binary_scalar_op.cuh" -#endif - #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_ diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu index 3fd017f09ec7..3c839205683a 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu @@ -57,19 +57,22 @@ NNVM_REGISTER_OP(_rdiv_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_backward_rdiv_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarOp::Backward); NNVM_REGISTER_OP(_mod_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_backward_mod_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarOp::Backward< + gpu, mshadow_op::mod_grad>); NNVM_REGISTER_OP(_rmod_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_backward_rmod_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarOp::Backward< + gpu, mshadow_op::rmod_grad>); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cu b/src/operator/tensor/elemwise_binary_scalar_op_extended.cu index f09e40a2eee7..2bd52d7b9d7c 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cu +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cu @@ -44,25 +44,30 @@ NNVM_REGISTER_OP(_power_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_backward_power_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarOp::Backward< + gpu, mshadow_op::power_grad>); NNVM_REGISTER_OP(_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_backward_rpower_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarOp::Backward< + gpu, mshadow_op::rpower_grad>); NNVM_REGISTER_OP(_hypot_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_backward_hypot_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarOp::Backward< + gpu, mshadow_op::hypot_grad_left>); NNVM_REGISTER_OP(smooth_l1) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", BinaryScalarOp::Compute< + gpu, mshadow_op::smooth_l1_loss>); NNVM_REGISTER_OP(_backward_smooth_l1) -.set_attr("FCompute", BinaryScalarOp::Backward); +.set_attr("FCompute", BinaryScalarOp::Backward< + gpu, mshadow_op::smooth_l1_gradient>); } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_sum.cu b/src/operator/tensor/elemwise_sum.cu index 352c74ea9445..f9a248214e85 100644 --- a/src/operator/tensor/elemwise_sum.cu +++ b/src/operator/tensor/elemwise_sum.cu @@ -24,118 +24,10 @@ */ #include "./elemwise_sum.h" #include "../../ndarray/ndarray_function.h" -#include "../../common/cuda_vectorization.cuh" namespace mxnet { namespace op { -using common::cuda::VectorizedKernelLauncher; -using common::cuda::VectorizedLoader; -using common::cuda::VectorizedStorer; - -namespace { - -constexpr size_t num_inputs_per_kernel = 4; - -template -struct VectorizedElementwiseSumKernelParams { - int num_inputs; - const DType* inputs[NumInputs]; - DType* outputs[1]; -}; - -template -__launch_bounds__(mxnet::common::cuda::vectorized_kernel_thread_num) -__global__ void VectorizedElementwiseSumKernel( - const VectorizedElementwiseSumKernelParams params, - const index_t N) { - VectorizedStorer storer(params.outputs[0], N); - - const index_t M = storer.num_aligned_elements(); - - for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { - if (req == kAddTo) { - storer.load(tid, N); - } else { -#pragma unroll - for (int i = 0; i < storer.nvec(); ++i) { - storer.separate()[i] = 0; - } - } -#pragma unroll - for (int i = 0; i < num_inputs_per_kernel; ++i) { - if (i < params.num_inputs) { - VectorizedLoader loader(params.inputs[i], N); - loader.load(tid, N); -#pragma unroll - for (int i = 0; i < loader.nvec(); ++i) { - storer.separate()[i] += loader.separate()[i]; - } - } - } - - storer.store(tid, N); - } -} - - -template -class VectorizedElementwiseSumFwd { - public: - using ParamType = VectorizedElementwiseSumKernelParams; - - template - static void Launch(const index_t blocks, const index_t threads, - cudaStream_t stream, - const ParamType params, const index_t lead_dim, - const index_t /* other_dim */) { - VectorizedElementwiseSumKernel - <<>>(params, lead_dim); - } -}; - -void VectorizedElementwiseSum(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - mshadow::Stream *s = ctx.get_stream(); - if (req[0] == kNullOp) return; - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - using LType = uint2; - const index_t size = inputs[0].Size(); - for (size_t i = 0; i < inputs.size(); i += num_inputs_per_kernel) { - if (i == 0) { - using Kernel = VectorizedElementwiseSumFwd; - typename Kernel::ParamType params; - params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i); - for (int j = 0; j < params.num_inputs; ++j) { - params.inputs[j] = inputs[i + j].dptr(); - } - params.outputs[0] = outputs[0].dptr(); - VectorizedKernelLauncher(size, 1, s, params); - } else { - /* During subsequent launches we need to - accumulate into the previous outputs - */ - using Kernel = VectorizedElementwiseSumFwd; - typename Kernel::ParamType params; - params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i); - for (int j = 0; j < params.num_inputs; ++j) { - params.inputs[j] = inputs[i + j].dptr(); - } - params.outputs[0] = outputs[0].dptr(); - VectorizedKernelLauncher(size, 1, s, params); - } - } - }); - }); -} - void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector& inputs, @@ -159,10 +51,8 @@ void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs, } } -} // namespace - NNVM_REGISTER_OP(add_n) -.set_attr("FCompute", VectorizedElementwiseSum) +.set_attr("FCompute", ElementWiseSumComputeWithHalf2) .set_attr("FComputeEx", ElementWiseSumComputeExGPU); } // namespace op diff --git a/src/operator/tensor/elemwise_sum.h b/src/operator/tensor/elemwise_sum.h index d40ab4de0f0f..259c80ddddac 100644 --- a/src/operator/tensor/elemwise_sum.h +++ b/src/operator/tensor/elemwise_sum.h @@ -113,6 +113,18 @@ void ElementWiseSumCompute(const nnvm::NodeAttrs& attrs, }); } +template +void ElementWiseSumComputeWithHalf2(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(outputs.size(), 1U); + MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { + ElementWiseSumCompute_(attrs, ctx, inputs, req, outputs); + }); +} + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_SUM_H_ diff --git a/src/operator/tensor/elemwise_unary_op.cuh b/src/operator/tensor/elemwise_unary_op.cuh deleted file mode 100644 index 8688a8b8ac66..000000000000 --- a/src/operator/tensor/elemwise_unary_op.cuh +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2020 by Contributors - * \file elemwise_unary_op.cuh - * \brief GPU helpers for unary elementwise operators - */ - -#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_CUH_ -#define MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_CUH_ - -#include -#include "../operator_common.h" -#include "../../common/cuda_vectorization.cuh" - -#include - -#if MXNET_USE_CUDA - -namespace mxnet { -namespace op { - -namespace unary { - -using common::cuda::VectorizedKernelLauncher; -using common::cuda::VectorizedLoader; -using common::cuda::VectorizedStorer; - -template -struct VectorizedKernelParams { - const DType* inputs[NumInputs]; - DType* outputs[NumOutputs]; -}; - -template -__global__ void VectorizedUnaryScalarKernelFwd(const VectorizedKernelParams params, - const index_t N) { - VectorizedLoader loader(params.inputs[0], N); - VectorizedStorer storer(params.outputs[0], N); - - const index_t M = loader.num_aligned_elements(); - - for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x; - tid < M; - tid += gridDim.x * blockDim.x) { - loader.load(tid, N); - if (req == kAddTo) { - storer.load(tid, N); - } -#pragma unroll - for (int i = 0; i < loader.nvec(); ++i) { - DType temp = OP::Map(loader.separate()[i]); - - if (req == kAddTo) { - storer.separate()[i] += temp; - } else { - storer.separate()[i] = temp; - } - } - storer.store(tid, N); - } -} - -template -class VectorizedUnaryScalarFwd { - public: - using ParamType = VectorizedKernelParams; - - template - static void Launch(const index_t blocks, const index_t threads, - cudaStream_t stream, - const ParamType params, const index_t lead_dim, - const index_t /* other_dim */) { - VectorizedUnaryScalarKernelFwd - <<>>(params, lead_dim); - } -}; - -} // namespace unary - -template -void UnaryOp::Compute_(const nnvm::NodeAttrs& attrs, - mshadow::Stream* s, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace unary; - if (req[0] == kNullOp) return; - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - using LType = uint4; - using Kernel = VectorizedUnaryScalarFwd; - - const index_t size = outputs[0].Size(); - typename Kernel::ParamType params; - params.inputs[0] = inputs[0].dptr(); - params.outputs[0] = outputs[0].dptr(); - - VectorizedKernelLauncher(size, 1, s, params); - }); - }); -} - -} // namespace op -} // namespace mxnet - -#endif // MXNET_USE_CUDA -#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_CUH_ diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 86686c6f1278..dcbd53aac69b 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -235,32 +235,6 @@ class UnaryOp : public OpBase { } } - template - static void Compute_(const nnvm::NodeAttrs& attrs, - mshadow::Stream* s, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - if (inputs[0].Size() != 0) { - mxnet_op::Kernel, cpu>::Launch( - s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr()); - } - }); - }); - } - -#if MXNET_USE_CUDA - template - static void Compute_(const nnvm::NodeAttrs& attrs, - mshadow::Stream* s, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs); - -#endif - template static void Compute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -268,7 +242,14 @@ class UnaryOp : public OpBase { const std::vector& req, const std::vector& outputs) { mshadow::Stream *s = ctx.get_stream(); - Compute_(attrs, s, inputs, req, outputs); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + if (inputs[0].Size() != 0) { + mxnet_op::Kernel, xpu>::Launch( + s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr()); + } + }); + }); } template @@ -363,6 +344,23 @@ class UnaryOp : public OpBase { } #endif + template + static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { + Kernel::Launch(s, outputs[0].Size(), + outputs[0].dptr(), inputs[0].dptr()); + }); + } + template static void IdentityCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -879,8 +877,4 @@ void NumpyNanToNumOpBackward(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet -#ifdef __CUDACC__ -#include "elemwise_unary_op.cuh" -#endif - #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_H_ diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu index 7c0550735519..e5b60b1726e6 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cu +++ b/src/operator/tensor/elemwise_unary_op_basic.cu @@ -22,7 +22,6 @@ * \brief GPU Implementation of unary functions. */ #include "./elemwise_binary_op.h" -#include "./elemwise_unary_op.h" namespace mxnet { namespace op { diff --git a/src/operator/tensor/elemwise_unary_op_pow.cu b/src/operator/tensor/elemwise_unary_op_pow.cu index 287a2e87be73..4dbdf349cdb0 100644 --- a/src/operator/tensor/elemwise_unary_op_pow.cu +++ b/src/operator/tensor/elemwise_unary_op_pow.cu @@ -22,7 +22,6 @@ * \brief GPU Implementation of power (x^k for fixed k) functions. */ #include "./elemwise_binary_op.h" -#include "./elemwise_unary_op.h" namespace mxnet { namespace op { diff --git a/src/operator/tensor/elemwise_unary_op_trig.cu b/src/operator/tensor/elemwise_unary_op_trig.cu index f5e9d1ccbd6c..8e28b9c609fa 100644 --- a/src/operator/tensor/elemwise_unary_op_trig.cu +++ b/src/operator/tensor/elemwise_unary_op_trig.cu @@ -22,7 +22,6 @@ * \brief GPU Implementation of unary trigonometric function. */ #include "./elemwise_binary_op.h" -#include "./elemwise_unary_op.h" namespace mxnet { namespace op { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e22d529eeb41..c73b8456240b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9895,85 +9895,6 @@ def test_elemwise_sum_for_gradient_accumulation(): assert stored_grad['write'] == stored_grad['add'] assert stored_grad['write'] == 2 * nrepeat -@with_seed() -def test_elementwise_ops_on_misaligned_input(): - a = mx.nd.array([1,2,3,4], dtype='float16') - b = mx.nd.array([1,2,3,4], dtype='float16') - - c = a[1:3] - d = b[1:3] - # Note: testing just elemwise_add since all elemwise_ops - # share the implementation - mx.nd.elemwise_add(c, d, out=c) - mx.nd.waitall() - - a = mx.nd.array([1,2,3,4], dtype='float16') - b = mx.nd.array([1,2,3,4], dtype='float16') - - c = a[0:3] - d = b[0:3] - mx.nd.elemwise_add(c, d, out=c) - mx.nd.waitall() - assert a[3].asscalar() == 4.0 - -@with_seed() -def test_broadcast_ops_on_misaligned_input(): - dtypes = ['float16', 'float32', 'float64'] - lead_dims = [2,3,4,6,10] - - for dtype in dtypes: - for lead_dim in lead_dims: - for both_ways in [False, True]: - shape = list(rand_shape_2d()) + [lead_dim] - small_shape = [shape[0], 1, lead_dim] - if both_ways: - # Broadcast in both ways [1, K, L] x [M, 1, L] - big_shape = [1, shape[1], lead_dim] - else: - big_shape = shape - size = np.product(shape) - small_size = np.product(small_shape) - big_size = np.product(big_shape) - a = mx.nd.arange(5000) - b = mx.nd.arange(5000) - e = mx.nd.arange(5000) - c = a[1:big_size + 1].reshape(big_shape) - d = b[1:small_size + 1].reshape(small_shape) - f = e[1:size + 1].reshape(shape) - mx.nd.broadcast_add(c, d, out=f) - expected = c.asnumpy() + d.asnumpy() - mx.nd.waitall() - assert_almost_equal(f, expected) - -@with_seed() -def test_broadcast_ops_on_misaligned_input_oneside(): - dtypes = ['float16', 'float32', 'float64'] - lead_dims = [2,3,4,6,10] - - for dtype in dtypes: - for lead_dim in lead_dims: - for both_ways in [False, True]: - shape = list(rand_shape_2d()) + [lead_dim] - small_shape = [shape[0], shape[1], 1] - if both_ways: - # Broadcast in both ways [1, K, L] x [M, 1, 1] - big_shape = [1, shape[1], lead_dim] - else: - big_shape = shape - size = np.product(shape) - small_size = np.product(small_shape) - big_size = np.product(big_shape) - a = mx.nd.arange(5000) - b = mx.nd.arange(5000) - e = mx.nd.arange(5000) - c = a[1:big_size + 1].reshape(big_shape) - d = b[1:small_size + 1].reshape(small_shape) - f = e[1:size + 1].reshape(shape) - mx.nd.broadcast_add(c, d, out=f) - expected = c.asnumpy() + d.asnumpy() - mx.nd.waitall() - assert_almost_equal(f, expected) - def test_scalarop_locale_invariance(): arr = mx.nd.zeros((1,)) prev = locale.getlocale(locale.LC_NUMERIC) @@ -9993,7 +9914,7 @@ def test_scalarop_locale_invariance(): break except locale.Error as e: print("Couldn't enable locale", loc, ": ", str(e)) - + if locale_set: scalar = 0.3 assert "," in locale.str(scalar)