BFloat16 test for SkipSimplifiedLayerNormalization
jiafatom committed Nov 25, 2024
1 parent ff57ac4 commit e5f20f2
Showing 5 changed files with 141 additions and 24 deletions.
35 changes: 35 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
@@ -25,6 +25,7 @@ limitations under the License.
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cublas_v2.h>
#include <cub/cub.cuh>

@@ -60,6 +61,24 @@ __device__ inline half2 AddHalf2(const half2 a, const half2 b) {
#endif
}

template <>
__device__ inline nv_bfloat16 Rsqrt(const nv_bfloat16& x) {
#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
return hrsqrt(x);
#else
return nv_bfloat16(rsqrtf(float(x)));
#endif
}

__device__ inline nv_bfloat162 AddHalf2(const nv_bfloat162 a, const nv_bfloat162 b) {
#if defined(__CUDACC__) && (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
return __hadd2(a, b);
#else
return __halves2bfloat162(__hadd(a.x, b.x), __hadd(a.y, b.y));
#endif
}

struct KeyValuePairSum {
__device__ inline cub::KeyValuePair<float, float> operator()(const cub::KeyValuePair<float, float>& a,
const cub::KeyValuePair<float, float>& b) {
@@ -78,6 +97,22 @@ struct KeyValuePairSum {
const cub::KeyValuePair<half2, half2>& b) {
return cub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}


__device__ inline cub::KeyValuePair<nv_bfloat16, nv_bfloat16> operator()(const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& a,
const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& b) {
const nv_bfloat162 a2 = __halves2bfloat162(a.key, a.value);
const nv_bfloat162 b2 = __halves2bfloat162(b.key, b.value);
const nv_bfloat162 res = AddHalf2(a2, b2);
return cub::KeyValuePair<nv_bfloat16, nv_bfloat16>(__low2bfloat16(res), __high2bfloat16(res));
}

__device__ inline cub::KeyValuePair<nv_bfloat162, nv_bfloat162> operator()(const cub::KeyValuePair<nv_bfloat162, nv_bfloat162>& a,
const cub::KeyValuePair<nv_bfloat162, nv_bfloat162>& b) {
return cub::KeyValuePair<nv_bfloat162, nv_bfloat162>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}
};

template <typename T, int TPB>
45 changes: 32 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -34,6 +34,7 @@ namespace cuda {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

using namespace ONNX_NAMESPACE;

@@ -105,19 +106,37 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
} else {
LaunchSkipLayerNormKernel<CudaT, Simplified>(
Stream(ctx),
reinterpret_cast<CudaT*>(output->MutableData<T>()),
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
epsilon_,
hidden_size,
row_count,
skip_size);
if constexpr (std::is_same<T, BFloat16>::value) {
LaunchSkipLayerNormKernel<nv_bfloat16, Simplified>(
Stream(ctx),
reinterpret_cast<nv_bfloat16*>(output->MutableData<T>()),
sum_output != nullptr ? reinterpret_cast<nv_bfloat16*>(sum_output->MutableData<T>()) : nullptr,
reinterpret_cast<const nv_bfloat16*>(input->Data<T>()),
reinterpret_cast<const nv_bfloat16*>(skip->Data<T>()),
(bias != nullptr) ? reinterpret_cast<const nv_bfloat16*>(bias->Data<T>()) : nullptr,
reinterpret_cast<const nv_bfloat16*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const nv_bfloat16*>(beta->Data<T>()) : nullptr,
epsilon_,
hidden_size,
row_count,
skip_size);
} else {
LaunchSkipLayerNormKernel<CudaT, Simplified>(
Stream(ctx),
reinterpret_cast<CudaT*>(output->MutableData<T>()),
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
reinterpret_cast<const CudaT*>(input->Data<T>()),
reinterpret_cast<const CudaT*>(skip->Data<T>()),
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
reinterpret_cast<const CudaT*>(gamma->Data<T>()),
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
epsilon_,
hidden_size,
row_count,
skip_size);
}
}

CUDA_RETURN_IF_ERROR(cudaGetLastError());
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -30,6 +30,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/layer_norm.cuh"
#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

namespace onnxruntime {
namespace contrib {
@@ -49,6 +50,11 @@ half maybe2half(float x) {
return __float2half_rn(x);
}

template <>
nv_bfloat16 maybe2half(float x) {
return __float2bfloat16_rn(x);
}

// Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
// in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192};
@@ -263,6 +269,8 @@ SKIPLAYERNORM_IMPL(float, true);
SKIPLAYERNORM_IMPL(float, false);
SKIPLAYERNORM_IMPL(half, true);
SKIPLAYERNORM_IMPL(half, false);
SKIPLAYERNORM_IMPL(nv_bfloat16, true);
SKIPLAYERNORM_IMPL(nv_bfloat16, false);

} // namespace cuda
} // namespace contrib
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -123,6 +123,7 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu);
@@ -327,6 +328,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu)>,
75 changes: 64 additions & 11 deletions onnxruntime/python/tools/transformers/float16.py
@@ -17,6 +17,8 @@
import os
import tempfile
from typing import Dict
from enum import Enum
import ml_dtypes

import numpy as np
import onnx
@@ -36,7 +38,6 @@ def _npfloat16_to_int(np_list):
"""
return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list]


def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0):
"""
Convert float32 numpy array to float16 without changing sign or finiteness.
@@ -107,6 +108,43 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit
tensor.raw_data = float16_list.tobytes()
return tensor

def convert_tensor_float_to_bfloat16(tensor):
"""Convert tensor float to bfloat16.
Args:
tensor (TensorProto): the tensor to convert.
min_positive_val (float, optional): minimal positive value. Defaults to 1e-7.
max_finite_val (float, optional): maximal finite value. Defaults to 1e4.
Raises:
ValueError: input type is not TensorProto.
Returns:
TensorProto: the converted tensor.
"""

if not isinstance(tensor, TensorProto):
raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")

if tensor.data_type == TensorProto.FLOAT:
tensor.data_type = TensorProto.BFLOAT16
# convert float_data (float type) to bfloat16 and write to int32_data
if tensor.float_data:
bfloat16_data = np.asarray(tensor.float_data, dtype=np.float32).astype(ml_dtypes.bfloat16)
# _npfloat16_to_int works here as well because float16 and bfloat16 are both 16 bits wide.
int_list = _npfloat16_to_int(bfloat16_data)
tensor.int32_data[:] = int_list
tensor.float_data[:] = []
# convert raw_data (bytes type)
if tensor.raw_data:
# convert tensor.raw_data to float32
float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
# convert float to bfloat16
bfloat16_list = float32_list.astype(ml_dtypes.bfloat16)
# convert bfloat16 to bytes and write back to raw_data
tensor.raw_data = bfloat16_list.tobytes()
return tensor
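
For reference, a minimal standalone sketch (not part of this commit) of what the raw_data branch above does; it assumes numpy and ml_dtypes are installed, and the sample values are illustrative only:

import ml_dtypes
import numpy as np

float32_values = np.array([1.0, 3.14159, 65504.0, 1e-8], dtype=np.float32)

# Same steps as the raw_data path: reinterpret the buffer as float32,
# truncate to bfloat16, then serialize the 16-bit values back to bytes.
raw = float32_values.tobytes()
as_float32 = np.frombuffer(raw, dtype=np.float32)
as_bfloat16 = as_float32.astype(ml_dtypes.bfloat16)
packed = as_bfloat16.tobytes()

assert len(packed) == 2 * len(float32_values)  # bfloat16 is 16 bits per element

# Viewing the bfloat16 buffer as uint16 exposes the raw bit patterns, which is
# why a 16-bit helper such as _npfloat16_to_int can be reused for bfloat16.
bit_patterns = as_bfloat16.view(np.uint16)
print(bit_patterns, np.asarray(as_bfloat16, dtype=np.float32))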


def make_value_info_from_tensor(tensor):
shape = numpy_helper.to_array(tensor).shape
@@ -148,6 +186,10 @@ def make_value_info_from_tensor(tensor):
# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this.
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}

class NodeValueType(Enum):
FP32 = 1
FP16 = 2
BF16 = 3

class InitializerTracker:
"""Class for keeping track of initializer."""
@@ -156,13 +198,15 @@ def __init__(self, initializer: TensorProto):
self.initializer = initializer
self.fp32_nodes = []
self.fp16_nodes = []
self.bf16_nodes = []

def add_node(self, node: NodeProto, is_node_blocked):
if is_node_blocked:
def add_node(self, node: NodeProto, node_value_type):
if node_value_type == NodeValueType.FP32:
self.fp32_nodes.append(node)
else:
elif node_value_type == NodeValueType.FP16:
self.fp16_nodes.append(node)

elif node_value_type == NodeValueType.BF16:
self.bf16_nodes.append(node)

def convert_float_to_float16(
model,
@@ -332,12 +376,17 @@ def convert_float_to_float16(
is_node_blocked = n.op_type in op_block_list or n.name in node_block_list
for i, input_name in enumerate(n.input):
if input_name in fp32_initializers:
# For Resize/GroupNorm, only the first input can be float16
use_fp32_weight = is_node_blocked or (
i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
and i not in force_fp16_inputs_dict.get(n.op_type, [])
)
fp32_initializers[input_name].add_node(n, use_fp32_weight)
if is_node_blocked and use_bfloat16_as_blocked_nodes_dtype:
fp32_initializers[input_name].add_node(n, NodeValueType.BF16)
else:
# For Resize/GroupNorm, only the first input can be float16
if is_node_blocked or (
i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
and i not in force_fp16_inputs_dict.get(n.op_type, [])
):
fp32_initializers[input_name].add_node(n, NodeValueType.FP32)
else:
fp32_initializers[input_name].add_node(n, NodeValueType.FP16)

if is_node_blocked:
node_list.append(n)
@@ -413,6 +462,10 @@ logger.info(
logger.info(
f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}"
)
if value.bf16_nodes:
value.initializer = convert_tensor_float_to_bfloat16(value.initializer)
value_info_list.append(make_value_info_from_tensor(value.initializer))


# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs.
for node in mixed_float_type_node_list:
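To tie the pieces together, a hypothetical usage sketch (not part of this commit) of driving the new flag from the conversion tool. The import path, model file names, and the op_block_list / use_bfloat16_as_blocked_nodes_dtype keyword arguments are assumptions based on the code above rather than confirmed API:

import onnx

from onnxruntime.transformers.float16 import convert_float_to_float16

# Convert an fp32 model to fp16 while keeping SkipSimplifiedLayerNormalization on
# the block list; with the flag set, initializers feeding blocked nodes are stored
# as bfloat16 instead of being left in float32.
model = onnx.load("model_fp32.onnx")  # placeholder path
converted = convert_float_to_float16(
    model,
    op_block_list=["SkipSimplifiedLayerNormalization"],
    use_bfloat16_as_blocked_nodes_dtype=True,
)
onnx.save(converted, "model_blocked_bf16.onnx")  # placeholder path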
