Refactor GELU and Sigmoid epilogue to use a common template (and add SiLu, Hardswish epilogue) (#379)

* Support half precision sigmoid activation

* introduce a vectorized variant using fast_tanh

* refactored sigmoid using the new interface

* refactored gelu

* add silu activation

* add hardswish

* remove sigmoid for now

* add description to silu and hardswish, and other doc update

* Do not ignore Round

* use constant N

* Set isHeavy = true in sigmoid and silu epilogue
masahi authored Dec 18, 2021
1 parent ec4f7e5 commit 0dc3ba6
Showing 6 changed files with 403 additions and 309 deletions.
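The commit message describes moving the GELU and Sigmoid epilogues onto a common LinearCombinationGeneric template and adding SiLu and Hardswish epilogues. Only two of the six changed files are expanded below, so as a hedged sketch, the new SiLu epilogue header presumably mirrors the GELU alias shown in the second file; the name LinearCombinationSilu and its defaulted parameters here are assumptions, not an excerpt from the commit:

#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"

namespace cutlass {
namespace epilogue {
namespace thread {

// Assumed shape of the new SiLu epilogue: a thin alias over the shared template,
// with the trailing 'true' matching the "Set isHeavy = true" item in the commit message.
template <
  typename ElementOutput_,                        ///< Data type used to load and store tensors
  int Count,                                      ///< Number of elements computed per operation
  typename ElementAccumulator_ = ElementOutput_,  ///< Accumulator data type
  typename ElementCompute_ = ElementOutput_,      ///< Data type used to compute linear combination
  FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
using LinearCombinationSilu = LinearCombinationGeneric<SiLu, ElementOutput_, Count, ElementAccumulator_,
                                                       ElementCompute_, Round, true>;

}  // namespace thread
}  // namespace epilogue
}  // namespace cutlass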
63 changes: 59 additions & 4 deletions include/cutlass/epilogue/thread/activation.h
@@ -126,6 +126,61 @@ struct Sigmoid<Array<T, N> > {
}
};

// SiLu (swish) operator introduced by Elfwing et al. in the following paper
// "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
// https://arxiv.org/pdf/1702.03118.pdf
// It is used in EfficientNet and YOLOv5, for example.
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
template <typename T>
struct SiLu {
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
Sigmoid<T> sigmoid_op;
return scalar * sigmoid_op(scalar);
}
};

template <typename T, int N>
struct SiLu<Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Sigmoid<Array<T, N>> sigmoid_op;
multiplies<Array<T, N>> mul;
return mul(rhs, sigmoid_op(rhs));
}
};

// Hardswish operator introduced by Howard et al. in the following paper
// "Searching for MobileNetV3" (2019)
// https://arxiv.org/pdf/1905.02244.pdf
// It is used in models based on MobileNetV3.
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
template <typename T>
struct HardSwish {
CUTLASS_HOST_DEVICE
T operator()(T const &x) const {
minimum<T> mn;
maximum<T> mx;
T relu6 = mn(mx(x + T(3), T(0)), T(6));
return x * (relu6 / T(6));
}
};

template <typename T, int N>
struct HardSwish<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> y;
HardSwish<T> hardswish_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = hardswish_op(rhs[i]);
}

return y;
}
};
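As an illustration only (not part of the commit), a minimal host-side sketch that exercises the two new functors; the element count of 4 and the sample values are arbitrary, and a CUTLASS include path is assumed:

#include <cstdio>

#include "cutlass/array.h"
#include "cutlass/epilogue/thread/activation.h"

int main() {
  // Both functors are CUTLASS_HOST_DEVICE, so they can be called on the host
  // exactly as the epilogue calls them on the device.
  cutlass::Array<float, 4> x;
  x[0] = -3.0f; x[1] = -1.0f; x[2] = 0.5f; x[3] = 4.0f;

  cutlass::epilogue::thread::SiLu<cutlass::Array<float, 4>> silu_op;
  cutlass::epilogue::thread::HardSwish<cutlass::Array<float, 4>> hardswish_op;

  cutlass::Array<float, 4> y_silu = silu_op(x);        // x * sigmoid(x)
  cutlass::Array<float, 4> y_hsw  = hardswish_op(x);   // x * relu6(x + 3) / 6

  for (int i = 0; i < 4; ++i) {
    std::printf("x=%6.2f  silu=%9.5f  hardswish=%9.5f\n",
                float(x[i]), float(y_silu[i]), float(y_hsw[i]));
  }
  return 0;
}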

//
// GELU function definitions implemented as described by
// Hendrycks, D., and Gimpel, K. in
@@ -189,7 +244,7 @@ struct GELU_taylor {
T k0 = T(0.7978845608028654);
T k1 = T(0.044715);

return T(cutlass::constants::half<T>() * z *
(cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
}
};
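For reference, the tanh approximation that GELU_taylor evaluates above (following Hendrycks & Gimpel) can be written, using the constants k0 and k1 from the code, as

\mathrm{GELU}(z) \approx \tfrac{1}{2}\, z \left( 1 + \tanh\!\bigl( k_0\, z\, (1 + k_1 z^{2}) \bigr) \right),
\qquad k_0 = \sqrt{2/\pi} \approx 0.7978845608, \quad k_1 = 0.044715.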
@@ -199,10 +254,10 @@ struct GELU_taylor<Array<half_t, N> > {
static const bool kIsHeavy=true;
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &z) const {

using T = half_t;
Array<half_t, N> y;

half_t k0 = half_t(0.7978845608028654);
half_t k1 = half_t(0.044715);

@@ -250,7 +305,7 @@ struct dGELU {

T tanh_out = fast_tanh(k0 * z * (1 + k1 * z * z));

T ff = constants::half<T>() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) +
constants::half<T>() * (1 + tanh_out);

return ff * d_t;
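The dGELU fragment above backpropagates through that approximation: d_t is the incoming gradient and ff is the local derivative. As a sketch of the math, assuming the constant k2 defined just above this hunk equals 3*k0*k1, with u = k_0 z (1 + k_1 z^{2}):

\frac{\partial\,\mathrm{GELU}(z)}{\partial z}
\approx \tfrac{1}{2}\, z \,\bigl(1 - \tanh^{2} u\bigr)\bigl(k_0 + 3 k_0 k_1 z^{2}\bigr)
 + \tfrac{1}{2}\bigl(1 + \tanh u\bigr),

and dGELU returns this quantity multiplied by d_t.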
158 changes: 5 additions & 153 deletions include/cutlass/epilogue/thread/linear_combination_gelu.h
@@ -29,12 +29,8 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -44,9 +40,9 @@ namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
/// Applies a linear combination operator followed by the GELU activation to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
/// D = gelu(alpha * accumulator + beta * source + uniform)
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
@@ -57,153 +53,9 @@ template <
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationGELU {
public:

using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;

static bool const kIsHeavy = true;

static int const kCount = Count;

using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;

static FloatRoundStyle const kRound = Round;

/// Host-constructable parameters structure
struct Params {

ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory

//
// Methods
//

CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }

CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

}

CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

}
};

private:

//
// Data members
//

ElementCompute alpha_;
ElementCompute beta_;

public:

/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationGELU(Params const &params) {

alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}

/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_ != ElementCompute(0);
}

/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}

CUTLASS_UNUSED(k_partition_count);
}

/// Computes: D = gelu( alpha * accumulator + beta * source )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {

// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);

// Perform binary operations

ComputeFragment intermediate;

multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;

intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

intermediate = gelu(intermediate);

// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

return destination_converter(intermediate);
}

/// Computes: D = gelu( alpha * accumulator )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {

// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

ComputeFragment converted_accumulator = accumulator_converter(accumulator);

// Perform binary operations

ComputeFragment intermediate;

multiplies<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;

intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum

intermediate = gelu(intermediate);

// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

return destination_converter(intermediate);
}
};

using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
                                                       ElementCompute_, Round, true>;
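Downstream usage is unchanged by the refactor. As a sketch (the fp16/fp32 element types and the 128-bit access width below are illustrative assumptions, not taken from this commit), a kernel would still name the epilogue like this:

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"

// Illustrative element types for an fp16 GEMM with fp32 accumulation.
using ElementOutput      = cutlass::half_t;
using ElementAccumulator = float;
using ElementCompute     = float;

// Elements handled per epilogue access; 128 bits / 16-bit halves = 8 is a common choice.
static int const kElementsPerAccess = 128 / cutlass::sizeof_bits<ElementOutput>::value;

// After this commit LinearCombinationGELU is an alias of LinearCombinationGeneric<GELU, ...>,
// but the spelling kernels use stays the same.
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
    ElementOutput,
    kElementsPerAccess,
    ElementAccumulator,
    ElementCompute>;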

/////////////////////////////////////////////////////////////////////////////////////////////////

