Skip to content

Commit

Permalink
Update skip layer norm (#22719)
Browse files Browse the repository at this point in the history
Update the `SkipLayerNorm` implementation to address issues.
  • Loading branch information
amarin16 authored and guschmue committed Dec 2, 2024
1 parent 6088ae8 commit 5b7266d
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 83 deletions.
145 changes: 78 additions & 67 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,13 @@ void ComputeJob(
const T* gamma_data,
const T* beta_data,
const T* bias_data,
IAllocatorUniquePtr<float>& skip_float_uptr,
IAllocatorUniquePtr<float>& gamma_float_uptr,
IAllocatorUniquePtr<float>& beta_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
T* output_data,
T* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
ORT_UNUSED_PARAMETER(skip_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(gamma_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(beta_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(bias_float_uptr); // only used in MLFloat16 overload
ORT_UNUSED_PARAMETER(alloc);

T* skip_input_bias_add_output_data) {
auto offset = task_idx * hidden_size;
const T* p_input = input_data + offset;
const T* p_skip = skip_data + (offset % skip_size);
Expand Down Expand Up @@ -110,13 +99,11 @@ void ComputeJob(
void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const MLFloat16* gamma_data,
const MLFloat16* beta_data,
const MLFloat16* bias_data,
IAllocatorUniquePtr<float>& skip_float_uptr,
IAllocatorUniquePtr<float>& gamma_float_uptr,
IAllocatorUniquePtr<float>& beta_float_uptr,
IAllocatorUniquePtr<float>& bias_float_uptr,
const float* prepacked_skip_fp32_data,
const float* gamma_float_ptr,
const float* beta_float_ptr,
const float* bias_float_ptr,
float* output_float_ptr,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
Expand All @@ -127,7 +114,6 @@ void ComputeJob(
AllocatorPtr alloc) {
auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
const MLFloat16* p_skip = skip_data + (offset % skip_size);
MLFloat16* p_output = output_data + offset;
MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset;

Expand All @@ -138,26 +124,19 @@ void ComputeJob(
IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);

if (!skip_float_uptr) {
IAllocatorUniquePtr<float> skip_float_uptr = nullptr;
if (prepacked_skip_fp32_data == nullptr && skip_data) {
const MLFloat16* p_skip = skip_data + (offset % skip_size);
skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
}

if (bias_data && !bias_float_uptr) {
bias_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_float_uptr.get(), num_elems);
}

IAllocatorUniquePtr<float> output_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
float* output_float_ptr = output_float_uptr.get();

const float* input_float_ptr = input_float_uptr.get();
const float* skip_float_ptr = skip_float_uptr.get();
const float* bias_float_ptr = bias_float_uptr.get();
const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data : skip_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
float val = input_float_ptr[h] + skip_float_ptr[h];

if (bias_float_uptr) {
if (bias_float_ptr) {
val += bias_float_ptr[h];
}

Expand All @@ -177,22 +156,10 @@ void ComputeJob(
mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon);
}

if (!gamma_float_uptr) {
gamma_float_uptr = std::move(input_float_uptr); // overwrite input with gamma values, since they have the same size
MlasConvertHalfToFloatBuffer(gamma_data, gamma_float_uptr.get(), num_elems);
}

if (beta_data && !beta_float_uptr) {
beta_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_float_uptr.get(), num_elems);
}

const float* gamma_float_ptr = gamma_float_uptr.get();
const float* beta_float_ptr = beta_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h];
} else if (nullptr == beta_float_uptr) {
} else if (nullptr == beta_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h];
Expand All @@ -218,18 +185,23 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I

template <typename T, bool simplified>
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info), skip_fp32_(nullptr), gamma_fp32_(nullptr), beta_fp32_(nullptr), bias_fp32_(nullptr) {
: OpKernel(op_kernel_info),
prepacked_skip_fp32_size_(0),
prepacked_skip_fp32_data_(nullptr),
prepacked_gamma_fp32_data_(nullptr),
prepacked_beta_fp32_data_(nullptr),
prepacked_bias_fp32_data_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}

template <typename T, bool simplified>
Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input<Tensor>(0);
const Tensor* skip = p_ctx->Input<Tensor>(1);
const Tensor* gamma = p_ctx->Input<Tensor>(2);
const Tensor* beta = p_ctx->Input<Tensor>(3);
const Tensor* bias = p_ctx->Input<Tensor>(4);
const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
Tensor* output = p_ctx->Output(0, input->Shape());
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
Expand All @@ -238,19 +210,21 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
size_t input_dims_size = input_dims.size();
int hidden_size = static_cast<int>(input_dims[input_dims_size - 1]);

ORT_RETURN_IF_ERROR(onnxruntime::contrib::skip_layer_norm_helper::CheckInputs<Tensor>(input,
skip,
gamma,
beta,
bias,
hidden_size,
input_dims_size));
ORT_RETURN_IF_ERROR(skip_layer_norm_helper::CheckPotentiallyPrepackedInputs<Tensor>(input,
skip,
gamma,
beta,
bias,
hidden_size,
input_dims_size,
prepacked_skip_fp32_data_ != nullptr,
prepacked_gamma_fp32_data_ != nullptr));

int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);

const T* input_data = input->Data<T>();
const T* skip_data = skip->Data<T>();
const T* gamma_data = gamma->Data<T>();
const T* skip_data = skip == nullptr ? nullptr : skip->Data<T>();
const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data<T>();
const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();

Expand All @@ -259,17 +233,53 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData<T>();

const int64_t& skip_size = skip->Shape().Size();
const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_;

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> output_fp32;
IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;

if constexpr (std::is_same_v<T, MLFloat16>) {
const size_t num_elems = static_cast<size_t>(hidden_size);

output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);

if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
}

if (prepacked_beta_fp32_data_ == nullptr && beta_data) {
beta_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems);
}

if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
}
}

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, skip_fp32_, gamma_fp32_, beta_fp32_,
bias_fp32_, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
if constexpr (std::is_same_v<T, MLFloat16>) {
ComputeJob(input_data, skip_data,
prepacked_skip_fp32_data_.get(),
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
output_fp32.get(),
task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data, alloc);
} else {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
epsilon_, simplified, output_data, skip_input_bias_add_output_data);
}
},
0);

Expand All @@ -283,13 +293,14 @@ Status SkipLayerNorm<T, simplified>::PrePack(const Tensor& tensor, int input_idx

is_packed = false;
if (input_idx == 1) { // skip
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, skip_fp32_, is_packed);
prepacked_skip_fp32_size_ = tensor.Shape().Size();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed);
} else if (input_idx == 2) { // gamma
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, gamma_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed);
} else if (input_idx == 3) { // beta
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, beta_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed);
} else if (input_idx == 4) { // bias
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, bias_fp32_, is_packed);
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

return Status::OK();
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ class SkipLayerNorm final : public OpKernel {

private:
float epsilon_;
mutable IAllocatorUniquePtr<float> skip_fp32_;
mutable IAllocatorUniquePtr<float> gamma_fp32_;
mutable IAllocatorUniquePtr<float> beta_fp32_;
mutable IAllocatorUniquePtr<float> bias_fp32_;
int64_t prepacked_skip_fp32_size_;
IAllocatorUniquePtr<float> prepacked_skip_fp32_data_;
IAllocatorUniquePtr<float> prepacked_gamma_fp32_data_;
IAllocatorUniquePtr<float> prepacked_beta_fp32_data_;
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
};

} // namespace contrib
Expand Down
Loading

0 comments on commit 5b7266d

Please sign in to comment.