
Refactor GELU and Sigmoid epilogue to use a common template (and add SiLu, Hardswish epilogue) #379

Merged: 12 commits, Dec 18, 2021
63 changes: 59 additions & 4 deletions include/cutlass/epilogue/thread/activation.h
@@ -124,6 +124,61 @@ struct Sigmoid<Array<T, N> > {
}
};

// SiLu (swish) operator introduced by Elfwing et al. in the following paper
// "Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning" (2017)
// https://arxiv.org/pdf/1702.03118.pdf
// It is used in EfficientNet and YOLOv5, for example.
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.SiLU.html
template <typename T>
struct SiLu {
CUTLASS_HOST_DEVICE
T operator()(T const &scalar) const {
Sigmoid<T> sigmoid_op;
return scalar * sigmoid_op(scalar);
}
};

template <typename T, int N>
struct SiLu<Array<T, N>> {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Sigmoid<Array<T, N>> sigmoid_op;
multiplies<Array<T, N>> mul;
return mul(rhs, sigmoid_op(rhs));
}
};
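
The new scalar functor can be sanity-checked on the host against the closed-form definition silu(x) = x * sigmoid(x) = x / (1 + exp(-x)) from the PyTorch reference above. A minimal sketch, assuming the CUTLASS include path and a CUDA-capable compiler are available; the test points and tolerance are arbitrary:

#include <cassert>
#include <cmath>
#include "cutlass/epilogue/thread/activation.h"

int main() {
  cutlass::epilogue::thread::SiLu<float> silu_op;
  const float xs[] = {-4.0f, -1.0f, 0.0f, 0.5f, 3.0f};
  for (float x : xs) {
    // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
    float expected = x / (1.0f + std::exp(-x));
    assert(std::fabs(silu_op(x) - expected) < 1e-6f);
  }
  return 0;
}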

// Hardswish operator introduced by Howard et al. in the following paper
// "Searching for MobileNetV3" (2019)
// https://arxiv.org/pdf/1905.02244.pdf
// It is used in models based on MobileNetV3.
// Reference: https://pytorch.org/docs/stable/generated/torch.nn.Hardswish.html
template <typename T>
struct HardSwish {
CUTLASS_HOST_DEVICE
T operator()(T const &x) const {
minimum<T> mn;
maximum<T> mx;
T relu6 = mn(mx(x + T(3), T(0)), T(6));
return x * (relu6 / T(6));
}
};

template <typename T, int N>
struct HardSwish<Array<T, N> > {
CUTLASS_HOST_DEVICE
Array<T, N> operator()(Array<T, N> const &rhs) const {
Array<T, N> y;
HardSwish<T> hardswish_op;

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N; ++i) {
y[i] = hardswish_op(rhs[i]);
}

return y;
}
};
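
The same style of host-side check applies to the new HardSwish functor, using the piecewise definition hardswish(x) = x * min(max(x + 3, 0), 6) / 6 from the PyTorch reference above. A sketch under the same assumptions as the SiLu check:

#include <algorithm>
#include <cassert>
#include <cmath>
#include "cutlass/epilogue/thread/activation.h"

int main() {
  cutlass::epilogue::thread::HardSwish<float> hardswish_op;
  const float xs[] = {-5.0f, -3.0f, -1.0f, 0.0f, 2.0f, 6.0f};
  for (float x : xs) {
    // hardswish(x) = x * relu6(x + 3) / 6
    float relu6 = std::min(std::max(x + 3.0f, 0.0f), 6.0f);
    float expected = x * relu6 / 6.0f;
    assert(std::fabs(hardswish_op(x) - expected) < 1e-6f);
  }
  return 0;
}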

//
// GELU function definitions implemented as described by
// Hendrycks, D., and Gimpel, K. in
@@ -187,7 +242,7 @@ struct GELU_taylor {
T k0 = T(0.7978845608028654);
T k1 = T(0.044715);

return T(cutlass::constants::half<T>() * z *
(cutlass::constants::one<T>() + fast_tanh(k0 * z * (cutlass::constants::one<T>() + k1 * z * z))));
}
};
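
For reference, the constants are those of the tanh approximation from the Hendrycks and Gimpel paper cited above, with k0 = sqrt(2/pi) ≈ 0.7978845608 and k1 = 0.044715, so the scalar functor evaluates

\mathrm{GELU}(z) \approx \tfrac{1}{2}\, z \left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(z + 0.044715\, z^{3}\bigr)\right)\right)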
@@ -197,10 +252,10 @@ struct GELU_taylor<Array<half_t, N> > {
static const bool kIsHeavy=true;
CUTLASS_HOST_DEVICE
Array<half_t, N> operator()(Array<half_t, N> const &z) const {

using T = half_t;
Array<half_t, N> y;

half_t k0 = half_t(0.7978845608028654);
half_t k1 = half_t(0.044715);

@@ -248,7 +303,7 @@ struct dGELU {

T tanh_out = fast_tanh(k0 * z * (1 + k1 * z * z));

T ff = constants::half<T>() * z * ((1 - tanh_out * tanh_out) * (k0 + k2 * z * z)) +
constants::half<T>() * (1 + tanh_out);

return ff * d_t;
158 changes: 5 additions & 153 deletions include/cutlass/epilogue/thread/linear_combination_gelu.h
@@ -29,12 +29,8 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/numeric_conversion.h"

#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

@@ -44,9 +40,9 @@ namespace thread {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Applies a linear combination operator to an array of elements.
/// Applies a linear combination operator followed by the GELU activation to an array of elements.
///
/// D = alpha * accumulator + beta * source + uniform
/// D = gelu(alpha * accumulator + beta * source + uniform)
///
template <
typename ElementOutput_, ///< Data type used to load and store tensors
@@ -57,153 +53,9 @@ template <
typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest
>
class LinearCombinationGELU {
public:

using ElementOutput = ElementOutput_;
using ElementAccumulator = ElementAccumulator_;
using ElementCompute = ElementCompute_;

static bool const kIsHeavy = true;

static int const kCount = Count;

using FragmentOutput = Array<ElementOutput, kCount>;
using FragmentAccumulator = Array<ElementAccumulator, kCount>;
using ComputeFragment = Array<ElementCompute, kCount>;

static FloatRoundStyle const kRound = Round;

/// Host-constructable parameters structure
struct Params {

ElementCompute alpha; ///< scales accumulators
ElementCompute beta; ///< scales source tensor
ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory
ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory

//
// Methods
//

CUTLASS_HOST_DEVICE
Params():
alpha(ElementCompute(1)),
beta(ElementCompute(0)),
alpha_ptr(nullptr),
beta_ptr(nullptr) { }

CUTLASS_HOST_DEVICE
Params(
ElementCompute alpha,
ElementCompute beta
): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) {

}

CUTLASS_HOST_DEVICE
Params(
ElementCompute const *alpha_ptr,
ElementCompute const *beta_ptr
): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) {

}
};

private:

//
// Data members
//

ElementCompute alpha_;
ElementCompute beta_;

public:

/// Constructs the function object, possibly loading from pointers in host memory
CUTLASS_HOST_DEVICE
LinearCombinationGELU(Params const &params) {

alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha);
beta_ = (params.beta_ptr ? *params.beta_ptr : params.beta);
}

/// Returns true if source is needed
CUTLASS_HOST_DEVICE
bool is_source_needed() const {
return beta_ != ElementCompute(0);
}

/// Functionally required for serial reduction in the epilogue
CUTLASS_HOST_DEVICE
void set_k_partition(int k_partition, int k_partition_count) {
if (k_partition) {
beta_ = ElementCompute(1);
}

CUTLASS_UNUSED(k_partition_count);
}

/// Computes: D = gelu( alpha * accumulator + beta * source )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator,
FragmentOutput const &source) const {

// Convert source to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter;
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

ComputeFragment converted_source = source_converter(source);
ComputeFragment converted_accumulator = accumulator_converter(accumulator);

// Perform binary operations

ComputeFragment intermediate;

multiplies<ComputeFragment> mul_add_source;
multiply_add<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;

intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform
intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X

intermediate = gelu(intermediate);

// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

return destination_converter(intermediate);
}

/// Computes: D = gelu( alpha * accumulator )
CUTLASS_HOST_DEVICE
FragmentOutput operator()(
FragmentAccumulator const &accumulator) const {

// Convert accumulator to internal compute numeric type
NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter;

ComputeFragment converted_accumulator = accumulator_converter(accumulator);

// Perform binary operations

ComputeFragment intermediate;

multiplies<ComputeFragment> mul_add_accumulator;
GELU<ComputeFragment> gelu;

intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum

intermediate = gelu(intermediate);

// Convert to destination numeric type
NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter;

return destination_converter(intermediate);
}
};

using LinearCombinationGELU = LinearCombinationGeneric<GELU, ElementOutput_, Count, ElementAccumulator_,
ElementCompute_, Round, true>;

/////////////////////////////////////////////////////////////////////////////////////////////////
