From 0dc3ba60b310a7e8915d235613e88bdbf5bf21b5 Mon Sep 17 00:00:00 2001
From: masahi
Date: Sun, 19 Dec 2021 04:58:15 +0900
Subject: [PATCH] Refactor GELU and Sigmoid epilogue to use a common template
 (and add SiLu, Hardswish epilogue) (#379)

* Support half precision sigmoid activation

* introduce a vectorized variant using fast_tanh

* refactored sigmoid using the new interface

* refactored gelu

* add silu activation

* add hardswish

* remove sigmoid for now

* add description to silu and hardswish, and other doc update

* Do not ignore Round

* use constant N

* Set isHeavy = true in sigmoid and silu epilogue
---
 include/cutlass/epilogue/thread/activation.h  |  63 +++++-
 .../epilogue/thread/linear_combination_gelu.h | 158 +------------
 .../thread/linear_combination_generic.h       | 209 ++++++++++++++++++
 .../thread/linear_combination_hardswish.h     |  62 ++++++
 .../thread/linear_combination_sigmoid.h       | 158 +------------
 .../epilogue/thread/linear_combination_silu.h |  62 ++++++
 6 files changed, 403 insertions(+), 309 deletions(-)
 create mode 100644 include/cutlass/epilogue/thread/linear_combination_generic.h
 create mode 100644 include/cutlass/epilogue/thread/linear_combination_hardswish.h
 create mode 100644 include/cutlass/epilogue/thread/linear_combination_silu.h

diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h
index a4e73b17e9..ce34be6397 100644
--- a/include/cutlass/epilogue/thread/activation.h
+++ b/include/cutlass/epilogue/thread/activation.h
@@ -126,6 +126,61 @@ struct Sigmoid<Array<half_t, N>> {
   }
 };
 
+// SiLu (swish) operator introduced by Elfwing et al. in the following paper
+// "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
+// https://arxiv.org/pdf/1702.03118.pdf
+// It is used in EfficientNet and YOLOv5, for example.
+// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
+template <typename T>
+struct SiLu {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &scalar) const {
+    Sigmoid<T> sigmoid_op;
+    return scalar * sigmoid_op(scalar);
+  }
+};
+
+template <typename T, int N>
+struct SiLu<Array<T, N>> {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs) const {
+    Sigmoid<Array<T, N>> sigmoid_op;
+    multiplies<Array<T, N>> mul;
+    return mul(rhs, sigmoid_op(rhs));
+  }
+};
+
+// Hardswish operator introduced by Howard et al. in the following paper
+// "Searching for MobileNetV3" (2019)
+// https://arxiv.org/pdf/1905.02244.pdf
+// It is used in models based on MobileNetV3.
+// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
+template <typename T>
+struct HardSwish {
+  CUTLASS_HOST_DEVICE
+  T operator()(T const &x) const {
+    minimum<T> mn;
+    maximum<T> mx;
+    T relu6 = mn(mx(x + T(3), T(0)), T(6));
+    return x * (relu6 / T(6));
+  }
+};
+
+template <typename T, int N>
+struct HardSwish<Array<T, N>> {
+  CUTLASS_HOST_DEVICE
+  Array<T, N> operator()(Array<T, N> const &rhs) const {
+    Array<T, N> y;
+    HardSwish<T> hardswish_op;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      y[i] = hardswish_op(rhs[i]);
+    }
+
+    return y;
+  }
+};
+
 //
 // GELU function definitions implemented as described by
 // Hendrycks, D., and Gimpel, K. in
@@ -189,7 +244,7 @@ struct GELU_taylor {
     T k0 = T(0.7978845608028654);
     T k1 = T(0.044715);
 
-    return T(cutlass::constants::half<T>() * z * 
+    return T(cutlass::constants::half<T>() * z *
       (cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
   }
 };
@@ -199,10 +254,10 @@ struct GELU_taylor<Array<half_t, N>> {
   static const bool kIsHeavy=true;
   CUTLASS_HOST_DEVICE
   Array<half_t, N> operator()(Array<half_t, N> const &z) const {
-    
+
     using T = half_t;
     Array<half_t, N> y;
-    
+
     half_t k0 = half_t(0.7978845608028654);
     half_t k1 = half_t(0.044715);
 
@@ -250,7 +305,7 @@ struct dGELU {
 
     T tanh_out = fast_tanh(k0 * z * (1 + k1 * z * z));
 
-    T ff = constants::half<T>() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) + 
+    T ff = constants::half<T>() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) +
       constants::half<T>() * (1 + tanh_out);
 
     return ff * d_t;
diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h
index 9eec618179..2bf05b7b20 100644
--- a/include/cutlass/epilogue/thread/linear_combination_gelu.h
+++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h
@@ -29,12 +29,8 @@
 #pragma once
 
 #include "cutlass/cutlass.h"
-#include "cutlass/numeric_types.h"
-#include "cutlass/array.h"
-#include "cutlass/functional.h"
-#include "cutlass/numeric_conversion.h"
-
 #include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/linear_combination_generic.h"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -44,9 +40,9 @@ namespace thread {
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
-/// Applies a linear combination operator to an array of elements.
+/// Applies a linear combination operator followed by the GELU activation to an array of elements.
 ///
-/// D = alpha * accumulator + beta * source + uniform
+/// D = gelu(alpha * accumulator + beta * source + uniform)
 ///
 template <
   typename ElementOutput_,                             ///< Data type used to load and store tensors
@@ -57,9 +53,9 @@ template <
   typename ElementCompute_ = ElementOutput_,           ///< Data type used to compute linear combination
   FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
 >
-class LinearCombinationGELU {
-public:
-
-  using ElementOutput = ElementOutput_;
-  using ElementAccumulator = ElementAccumulator_;
-  using ElementCompute = ElementCompute_;
-
-  static bool const kIsHeavy = true;
-
-  static int const kCount = Count;
-
-  using FragmentOutput = Array<ElementOutput, kCount>;
-  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
-  using ComputeFragment = Array<ElementCompute, kCount>;
-
-  static FloatRoundStyle const kRound = Round;
-
-  /// Host-constructable parameters structure
-  struct Params {
-
-    ElementCompute alpha;                  ///< scales accumulators
-    ElementCompute beta;                   ///< scales source tensor
-    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
-    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory
-
-    //
-    // Methods
-    //
-
-    CUTLASS_HOST_DEVICE
-    Params():
-      alpha(ElementCompute(1)),
-      beta(ElementCompute(0)),
-      alpha_ptr(nullptr),
-      beta_ptr(nullptr) { }
-
-    CUTLASS_HOST_DEVICE
-    Params(
-      ElementCompute alpha,
-      ElementCompute beta
-    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
-
-    }
-
-    CUTLASS_HOST_DEVICE
-    Params(
-      ElementCompute const *alpha_ptr,
-      ElementCompute const *beta_ptr
-    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
-
-    }
-  };
-
-private:
-
-  //
-  // Data members
-  //
-
-  ElementCompute alpha_;
-  ElementCompute beta_;
-
-public:
-
-  /// Constructs the function object, possibly loading from pointers in host memory
-  CUTLASS_HOST_DEVICE
-  LinearCombinationGELU(Params const &params) {
-
-    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
-    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
-  }
-
-  /// Returns true if source is needed
-  CUTLASS_HOST_DEVICE
-  bool is_source_needed() const {
-    return beta_ != ElementCompute(0);
-  }
-
-  /// Functionally required for serial reduction in the epilogue
-  CUTLASS_HOST_DEVICE
-  void set_k_partition(int k_partition, int k_partition_count) {
-    if (k_partition) {
-      beta_ = ElementCompute(1);
-    }
-
-    CUTLASS_UNUSED(k_partition_count);
-  }
-
-  /// Computes: D = gelu( alpha * accumulator + beta * source )
-  CUTLASS_HOST_DEVICE
-  FragmentOutput operator()(
-    FragmentAccumulator const &accumulator,
-    FragmentOutput const &source) const {
-
-    // Convert source to internal compute numeric type
-    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
-    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
-
-    ComputeFragment converted_source = source_converter(source);
-    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
-
-    // Perform binary operations
-
-    ComputeFragment intermediate;
-
-    multiplies<ComputeFragment> mul_add_source;
-    multiply_add<ComputeFragment> mul_add_accumulator;
-    GELU<ComputeFragment> gelu;
-
-    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
-    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
-
-    intermediate = gelu(intermediate);
-
-    // Convert to destination numeric type
-    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
-
-    return destination_converter(intermediate);
-  }
-
-  /// Computes: D = gelu( alpha * accumulator )
-  CUTLASS_HOST_DEVICE
-  FragmentOutput operator()(
-    FragmentAccumulator const &accumulator) const {
-
-    // Convert accumulator to internal compute numeric type
-    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
-
-    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
-
-    // Perform binary operations
-
-    ComputeFragment intermediate;
-
-    multiplies<ComputeFragment> mul_add_accumulator;
-    GELU<ComputeFragment> gelu;
-
-    intermediate = mul_add_accumulator(alpha_, converted_accumulator);    // D = alpha * Accum
-
-    intermediate = gelu(intermediate);
-
-    // Convert to destination numeric type
-    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
+using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
+                                                       ElementCompute_, Round, true>;
 
-    return destination_converter(intermediate);
-  }
-};
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
diff --git a/include/cutlass/epilogue/thread/linear_combination_generic.h b/include/cutlass/epilogue/thread/linear_combination_generic.h
new file mode 100644
index 0000000000..17f961e83b
--- /dev/null
+++ b/include/cutlass/epilogue/thread/linear_combination_generic.h
@@ -0,0 +1,209 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+  \brief Functor performing a linear combination followed by a generic activation function,
+    used by epilogues.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/numeric_types.h"
+#include "cutlass/array.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies a linear combination operator followed by an activation function to an array of elements.
+///
+/// D = activation(alpha * accumulator + beta * source + uniform)
+///
+template <
+  template <typename T> class ActivationFunctor,
+  typename ElementOutput_,                             ///< Data type used to load and store tensors
+  int Count,                                           ///< Number of elements computed per operation
+                                                       ///< Usually it is 128/sizeof_bits<ElementOutput_>::value,
+                                                       ///< but we use 64 or 32 sometimes when there are not enough data to store
+  typename ElementAccumulator_ = ElementOutput_,       ///< Accumulator data type
+  typename ElementCompute_ = ElementOutput_,           ///< Data type used to compute linear combination
+  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
+  bool IsHeavy = false
+>
+class LinearCombinationGeneric {
+public:
+
+  using ElementOutput = ElementOutput_;
+  using ElementAccumulator = ElementAccumulator_;
+  using ElementCompute = ElementCompute_;
+
+  static bool const kIsHeavy = IsHeavy;
+  static int const kCount = Count;
+
+  using FragmentOutput = Array<ElementOutput, kCount>;
+  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
+  using ComputeFragment = Array<ElementCompute, kCount>;
+
+  static FloatRoundStyle const kRound = Round;
+
+  /// Host-constructable parameters structure
+  struct Params {
+
+    ElementCompute alpha;                  ///< scales accumulators
+    ElementCompute beta;                   ///< scales source tensor
+    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
+    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory
+
+    //
+    // Methods
+    //
+
+    CUTLASS_HOST_DEVICE
+    Params():
+      alpha(ElementCompute(1)),
+      beta(ElementCompute(0)),
+      alpha_ptr(nullptr),
+      beta_ptr(nullptr) { }
+
+    CUTLASS_HOST_DEVICE
+    Params(
+      ElementCompute alpha,
+      ElementCompute beta
+    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
+
+    }
+
+    CUTLASS_HOST_DEVICE
+    Params(
+      ElementCompute const *alpha_ptr,
+      ElementCompute const *beta_ptr
+    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
+
+    }
+  };
+
+private:
+
+  //
+  // Data members
+  //
+
+  ElementCompute alpha_;
+  ElementCompute beta_;
+
+public:
+
+  /// Constructs the function object, possibly loading from pointers in host memory
+  CUTLASS_HOST_DEVICE
+  LinearCombinationGeneric(Params const &params) {
+
+    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
+    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
+  }
+
+  /// Returns true if source is needed
+  CUTLASS_HOST_DEVICE
+  bool is_source_needed() const {
+    return beta_ != ElementCompute(0);
+  }
+
+  /// Functionally required for serial reduction in the epilogue
+  CUTLASS_HOST_DEVICE
+  void set_k_partition(int k_partition, int k_partition_count) {
+    if (k_partition) {
+      beta_ = ElementCompute(1);
+    }
+  }
+
+  /// Computes: D = activation(alpha * accumulator + beta * source)
+  CUTLASS_HOST_DEVICE
+  FragmentOutput operator()(
+    FragmentAccumulator const &accumulator,
+    FragmentOutput const &source) const {
+
+    // Convert source to internal compute numeric type
+    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
+    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
+
+    ComputeFragment converted_source = source_converter(source);
+    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+    // Perform binary operations
+
+    ComputeFragment intermediate;
+
+    multiplies<ComputeFragment> mul_add_source;
+    multiply_add<ComputeFragment> mul_add_accumulator;
+    ActivationFunctor<ComputeFragment> activation;
+
+    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
+    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
+
+    intermediate = activation(intermediate);
+
+    // Convert to destination numeric type
+    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
+
+    return destination_converter(intermediate);
+  }
+
+  /// Computes: D = activation(alpha * accumulator)
+  CUTLASS_HOST_DEVICE
+  FragmentOutput operator()(
+    FragmentAccumulator const &accumulator) const {
+
+    // Convert accumulator to internal compute numeric type
+    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
+
+    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
+
+    // Perform binary operations
+
+    ComputeFragment intermediate;
+
+    multiplies<ComputeFragment> mul_add_accumulator;
+    ActivationFunctor<ComputeFragment> activation;
+
+    intermediate = mul_add_accumulator(alpha_, converted_accumulator);    // D = alpha * Accum
+
+    intermediate = activation(intermediate);
+
+    // Convert to destination numeric type
+    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
+
+    return destination_converter(intermediate);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
diff --git a/include/cutlass/epilogue/thread/linear_combination_hardswish.h b/include/cutlass/epilogue/thread/linear_combination_hardswish.h
new file mode 100644
index 0000000000..e6c37d506a
--- /dev/null
+++ b/include/cutlass/epilogue/thread/linear_combination_hardswish.h
@@ -0,0 +1,62 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+  \brief Functor performing linear combination with HardSwish operations used by epilogues.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/linear_combination_generic.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies a linear combination operator followed by the HardSwish activation to an array of elements.
+///
+/// D = hardswish(alpha * accumulator + beta * source + uniform)
+///
+template <
+  typename ElementOutput_,                             ///< Data type used to load and store tensors
+  int Count,                                           ///< Number of elements computed per operation
+                                                       ///< Usually it is 128/sizeof_bits<ElementOutput_>::value,
+                                                       ///< but we use 64 or 32 sometimes when there are not enough data to store
+  typename ElementAccumulator_ = ElementOutput_,       ///< Accumulator data type
+  typename ElementCompute_ = ElementOutput_,           ///< Data type used to compute linear combination
+  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
+>
+using LinearCombinationHardSwish = LinearCombinationGeneric<HardSwish, ElementOutput_, Count, ElementAccumulator_,
+                                                            ElementCompute_, Round>;
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
diff --git a/include/cutlass/epilogue/thread/linear_combination_sigmoid.h b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h
index 4716effaa6..e5ef55c80d 100644
--- a/include/cutlass/epilogue/thread/linear_combination_sigmoid.h
+++ b/include/cutlass/epilogue/thread/linear_combination_sigmoid.h
@@ -23,18 +23,14 @@
  *
  **************************************************************************************************/
 /*! \file
-  \brief Functor performing linear combination operations used by epilogues.
+  \brief Functor performing linear combination with Sigmoid operations used by epilogues.
 */
 
 #pragma once
 
 #include "cutlass/cutlass.h"
-#include "cutlass/numeric_types.h"
-#include "cutlass/array.h"
-#include "cutlass/functional.h"
-#include "cutlass/numeric_conversion.h"
-
 #include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/linear_combination_generic.h"
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -44,9 +40,9 @@ namespace thread {
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
-/// Applies a linear combination operator to an array of elements.
+/// Applies a linear combination operator followed by the Sigmoid activation to an array of elements.
 ///
-/// D = alpha * accumulator + beta * source + uniform
+/// D = sigmoid(alpha * accumulator + beta * source + uniform)
 ///
 template <
   typename ElementOutput_,                             ///< Data type used to load and store tensors
@@ -57,150 +53,8 @@ template <
   typename ElementCompute_ = ElementOutput_,           ///< Data type used to compute linear combination
   FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
 >
-class LinearCombinationSigmoid {
-public:
-
-  using ElementOutput = ElementOutput_;
-  using ElementAccumulator = ElementAccumulator_;
-  using ElementCompute = ElementCompute_;
-
-  static int const kCount = Count;
-
-  using FragmentOutput = Array<ElementOutput, kCount>;
-  using FragmentAccumulator = Array<ElementAccumulator, kCount>;
-  using ComputeFragment = Array<ElementCompute, kCount>;
-
-  static FloatRoundStyle const kRound = Round;
-
-  /// Host-constructable parameters structure
-  struct Params {
-
-    ElementCompute alpha;                  ///< scales accumulators
-    ElementCompute beta;                   ///< scales source tensor
-    ElementCompute const *alpha_ptr;       ///< pointer to accumulator scalar - if not null, loads it from memory
-    ElementCompute const *beta_ptr;        ///< pointer to source scalar - if not null, loads it from memory
-
-    //
-    // Methods
-    //
-
-    CUTLASS_HOST_DEVICE
-    Params():
-      alpha(ElementCompute(1)),
-      beta(ElementCompute(0)),
-      alpha_ptr(nullptr),
-      beta_ptr(nullptr) { }
-
-    CUTLASS_HOST_DEVICE
-    Params(
-      ElementCompute alpha,
-      ElementCompute beta
-    ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {
-
-    }
-
-    CUTLASS_HOST_DEVICE
-    Params(
-      ElementCompute const *alpha_ptr,
-      ElementCompute const *beta_ptr
-    ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {
-
-    }
-  };
-
-private:
-
-  //
-  // Data members
-  //
-
-  ElementCompute alpha_;
-  ElementCompute beta_;
-
-public:
-
-  /// Constructs the function object, possibly loading from pointers in host memory
-  CUTLASS_HOST_DEVICE
-  LinearCombinationSigmoid(Params const &params) {
-
-    alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
-    beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
-  }
-
-  /// Returns true if source is needed
-  CUTLASS_HOST_DEVICE
-  bool is_source_needed() const {
-    return beta_ != ElementCompute(0);
-  }
-
-  /// Functionally required for serial reduction in the epilogue
-  CUTLASS_HOST_DEVICE
-  void set_k_partition(int k_partition, int k_partition_count) {
-    if (k_partition) {
-      beta_ = ElementCompute(1);
-    }
-  }
-
-  /// Computes linear scaling: D = alpha * accumulator + beta * source
-  CUTLASS_HOST_DEVICE
-  FragmentOutput operator()(
-    FragmentAccumulator const &accumulator,
-    FragmentOutput const &source) const {
-
-    // Convert source to internal compute numeric type
-    NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
-    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
-
-    ComputeFragment converted_source = source_converter(source);
-    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
-
-    // Perform binary operations
-
-    ComputeFragment intermediate;
-
-    multiplies<ComputeFragment> mul_add_source;
-    multiply_add<ComputeFragment> mul_add_accumulator;
-    Sigmoid<ComputeFragment> sigmoid;
-
-    intermediate = mul_add_source(beta_, converted_source);                          // X = beta * C + uniform
-    intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X
-
-    intermediate = sigmoid(intermediate);
-
-    // Convert to destination numeric type
-    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
-
-    return destination_converter(intermediate);
-  }
-
-  /// Computes linear scaling: D = alpha * accumulator
-  CUTLASS_HOST_DEVICE
-  FragmentOutput operator()(
-    FragmentAccumulator const &accumulator) const {
-
-    // Convert accumulator to internal compute numeric type
-    NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;
-
-    ComputeFragment converted_accumulator = accumulator_converter(accumulator);
-
-    // Perform binary operations
-
-    ComputeFragment intermediate;
-
-    multiplies<ComputeFragment> mul_add_accumulator;
-    Sigmoid<ComputeFragment> sigmoid;
-
-    intermediate = mul_add_accumulator(alpha_, converted_accumulator);    // D = alpha * Accum
-
-    intermediate = sigmoid(intermediate);
-
-    // Convert to destination numeric type
-    NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;
-
-    return destination_converter(intermediate);
-  }
-};
-
+using LinearCombinationSigmoid = LinearCombinationGeneric<Sigmoid, ElementOutput_, Count, ElementAccumulator_,
+                                                          ElementCompute_, Round, true>;
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
 } // namespace thread
diff --git a/include/cutlass/epilogue/thread/linear_combination_silu.h b/include/cutlass/epilogue/thread/linear_combination_silu.h
new file mode 100644
index 0000000000..e9a3e2c935
--- /dev/null
+++ b/include/cutlass/epilogue/thread/linear_combination_silu.h
@@ -0,0 +1,62 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+  \brief Functor performing linear combination with SiLU operations used by epilogues.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/epilogue/thread/linear_combination_generic.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies a linear combination operator followed by the SiLU activation to an array of elements.
+///
+/// D = silu(alpha * accumulator + beta * source + uniform)
+///
+template <
+  typename ElementOutput_,                             ///< Data type used to load and store tensors
+  int Count,                                           ///< Number of elements computed per operation
+                                                       ///< Usually it is 128/sizeof_bits<ElementOutput_>::value,
+                                                       ///< but we use 64 or 32 sometimes when there are not enough data to store
+  typename ElementAccumulator_ = ElementOutput_,       ///< Accumulator data type
+  typename ElementCompute_ = ElementOutput_,           ///< Data type used to compute linear combination
+  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
+>
+using LinearCombinationSilu = LinearCombinationGeneric<SiLu, ElementOutput_, Count, ElementAccumulator_,
+                                                       ElementCompute_, Round, true>;
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
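
---

A quick way to exercise the refactored epilogues outside of a full GEMM: because LinearCombinationGeneric's methods are CUTLASS_HOST_DEVICE, the functor can be instantiated and applied on the host. The sketch below is not part of the patch; the float element type, the fragment width of 4, and alpha = beta = 1 are illustrative assumptions, and it assumes the CUTLASS include directory is on the compiler's include path. In a real kernel the alias is instead passed as the epilogue-operator template argument of a device-level GEMM (e.g. cutlass::gemm::device::Gemm).

    // Host-side sketch: apply LinearCombinationSilu<float, 4> to one fragment.
    #include <iostream>

    #include "cutlass/epilogue/thread/linear_combination_silu.h"

    int main() {
      // float in/out/accumulator, 4 elements per fragment (illustrative choices)
      using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<float, 4>;

      EpilogueOp::Params params(1.0f, 1.0f);   // alpha = 1, beta = 1
      EpilogueOp op(params);

      EpilogueOp::FragmentAccumulator accum;   // Array<float, 4>
      EpilogueOp::FragmentOutput source;       // Array<float, 4>
      for (int i = 0; i < 4; ++i) {
        accum[i] = float(i);                   // 0, 1, 2, 3
        source[i] = 0.5f;
      }

      // D = silu(alpha * accum + beta * source) = x * sigmoid(x) for x = i + 0.5
      EpilogueOp::FragmentOutput d = op(accum, source);

      for (int i = 0; i < 4; ++i) {
        std::cout << d[i] << (i + 1 < 4 ? ' ' : '\n');
      }
      return 0;
    }

Note that the illustrative width happens to match the rule quoted in the template comments: for float output, 128/sizeof_bits<float>::value = 128/32 = 4 elements per operation.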