From 03bf8396867691e59c373bcfccf26ed9664dcbcf Mon Sep 17 00:00:00 2001
From: David Fan
Date: Mon, 25 Nov 2024 17:40:55 +0000
Subject: [PATCH] BFloat16 test for SkipSimplifiedLayerNormalization

---
 .../contrib_ops/cuda/bert/layer_norm.cuh      | 24 ++++++
 .../contrib_ops/cuda/bert/skip_layer_norm.cc  | 45 +++++++----
 .../cuda/bert/skip_layer_norm_impl.cu         |  8 ++
 .../contrib_ops/cuda/cuda_contrib_kernels.cc  |  2 +
 .../core/graph/contrib_ops/bert_defs.cc       |  2 +-
 .../python/tools/transformers/float16.py      | 75 ++++++++++++++++---
 6 files changed, 131 insertions(+), 25 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
index ff3178b56c2a6..a087d0d546ecb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
+++ b/onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
@@ -25,6 +25,7 @@ limitations under the License.
 #include "core/providers/cuda/cu_inc/common.cuh"
 #include "core/providers/cuda/shared_inc/cuda_call.h"
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include <cub/cub.cuh>
 #include
 
@@ -60,6 +61,15 @@ __device__ inline half2 AddHalf2(const half2 a, const half2 b) {
 #endif
 }
 
+template <>
+__device__ inline nv_bfloat16 Rsqrt(const nv_bfloat16& x) {
+  return hrsqrt(x);
+}
+
+__device__ inline nv_bfloat162 AddHalf2(const nv_bfloat162 a, const nv_bfloat162 b) {
+  return __hadd2(a, b);
+}
+
 struct KeyValuePairSum {
   __device__ inline cub::KeyValuePair<float, float> operator()(const cub::KeyValuePair<float, float>& a,
                                                                const cub::KeyValuePair<float, float>& b) {
@@ -78,6 +88,20 @@ struct KeyValuePairSum {
                                                              const cub::KeyValuePair<half2, half2>& b) {
     return cub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
   }
+
+
+  __device__ inline cub::KeyValuePair<nv_bfloat16, nv_bfloat16> operator()(const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& a,
+                                                                           const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& b) {
+    const nv_bfloat162 a2 = __halves2bfloat162(a.key, a.value);
+    const nv_bfloat162 b2 = __halves2bfloat162(b.key, b.value);
+    const nv_bfloat162 res = AddHalf2(a2, b2);
+    return cub::KeyValuePair<nv_bfloat16, nv_bfloat16>(__low2bfloat16(res), __high2bfloat16(res));
+  }
+
+  __device__ inline cub::KeyValuePair<nv_bfloat162, nv_bfloat162> operator()(const cub::KeyValuePair<nv_bfloat162, nv_bfloat162>& a,
+                                                                             const cub::KeyValuePair<nv_bfloat162, nv_bfloat162>& b) {
+    return cub::KeyValuePair<nv_bfloat162, nv_bfloat162>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
+  }
 };
 
 template
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
index 3299bc2cb11de..6e17d18870bc6 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -34,6 +34,7 @@ namespace cuda {
 
 REGISTER_KERNEL_TYPED(float)
 REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
 
 using namespace ONNX_NAMESPACE;
 
@@ -105,19 +106,37 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
         (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,  // bias to add
         sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
   } else {
-    LaunchSkipLayerNormKernel<CudaT, Simplified>(
-        Stream(ctx),
-        reinterpret_cast<CudaT*>(output->MutableData<T>()),
-        sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
-        reinterpret_cast<const CudaT*>(input->Data<T>()),
-        reinterpret_cast<const CudaT*>(skip->Data<T>()),
-        (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
-        reinterpret_cast<const CudaT*>(gamma->Data<T>()),
-        (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
-        epsilon_,
-        hidden_size,
-        row_count,
-        skip_size);
+    if (std::is_same<T, BFloat16>::value) {
+      LaunchSkipLayerNormKernel<nv_bfloat16, Simplified>(
+          Stream(ctx),
+          reinterpret_cast<nv_bfloat16*>(output->MutableData<T>()),
+          sum_output != nullptr ? reinterpret_cast<nv_bfloat16*>(sum_output->MutableData<T>()) : nullptr,
+          reinterpret_cast<const nv_bfloat16*>(input->Data<T>()),
+          reinterpret_cast<const nv_bfloat16*>(skip->Data<T>()),
+          (bias != nullptr) ? reinterpret_cast<const nv_bfloat16*>(bias->Data<T>()) : nullptr,
+          reinterpret_cast<const nv_bfloat16*>(gamma->Data<T>()),
+          (beta != nullptr) ? reinterpret_cast<const nv_bfloat16*>(beta->Data<T>()) : nullptr,
+          epsilon_,
+          hidden_size,
+          row_count,
+          skip_size);
+    }
+    else
+    {
+      LaunchSkipLayerNormKernel<CudaT, Simplified>(
+          Stream(ctx),
+          reinterpret_cast<CudaT*>(output->MutableData<T>()),
+          sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
+          reinterpret_cast<const CudaT*>(input->Data<T>()),
+          reinterpret_cast<const CudaT*>(skip->Data<T>()),
+          (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
+          reinterpret_cast<const CudaT*>(gamma->Data<T>()),
+          (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
+          epsilon_,
+          hidden_size,
+          row_count,
+          skip_size);
+    }
   }
 
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
index 50c8e4b5e0398..caf46ad8a198f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -30,6 +30,7 @@ limitations under the License.
 #include "contrib_ops/cuda/bert/layer_norm.cuh"
 #include "contrib_ops/cuda/bert/skip_layer_norm_impl.h"
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 
 namespace onnxruntime {
 namespace contrib {
@@ -49,6 +50,11 @@ half maybe2half<half>(float x) {
   return __float2half_rn(x);
 }
 
+template <>
+nv_bfloat16 maybe2half<nv_bfloat16>(float x) {
+  return __float2bfloat16_rn(x);
+}
+
 // Using only power of 2 numbers will lead to waste of compute for same size such as 768, which is a very common case
 // in BERT. Ideally we can step by wrap_size * num_unroll, but listing too many steps will cause long compile time.
 constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192};
@@ -263,6 +269,8 @@ SKIPLAYERNORM_IMPL(float, true);
 SKIPLAYERNORM_IMPL(float, false);
 SKIPLAYERNORM_IMPL(half, true);
 SKIPLAYERNORM_IMPL(half, false);
+SKIPLAYERNORM_IMPL(nv_bfloat16, true);
+SKIPLAYERNORM_IMPL(nv_bfloat16, false);
 
 }  // namespace cuda
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 21bd5eb91c20f..845af598f6ec7 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -123,6 +123,7 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization);
 class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization);
+class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization);
 class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu);
 class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu);
 class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu);
@@ -327,6 +328,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index f2a2a52f8334f..348b7281cd042 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1611,7 +1611,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                 "with shape (batch_size, sequence_length, hidden_size) or (token_count, hidden_size).",
                 "T",
                 OpSchema::Optional)
-      .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
+      .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float, half or bfloat16 tensors.")
       .TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
       .TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference));
 
diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py
index 2398bb9d6031b..fa62be0751fc7 100644
--- a/onnxruntime/python/tools/transformers/float16.py
+++ b/onnxruntime/python/tools/transformers/float16.py
@@ -17,6 +17,8 @@
 import os
 import tempfile
 from typing import Dict
+from enum import Enum
+import ml_dtypes
 
 import numpy as np
 import onnx
@@ -36,7 +38,6 @@ def _npfloat16_to_int(np_list):
     """
     return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list]
 
-
 def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0):
     """
     Convert float32 numpy array to float16 without changing sign or finiteness.
@@ -107,6 +108,43 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit
         tensor.raw_data = float16_list.tobytes()
     return tensor
 
+def convert_tensor_float_to_bfloat16(tensor):
+    """Convert a float32 TensorProto to bfloat16.
+
+    Args:
+        tensor (TensorProto): the tensor to convert. Only tensors with data type
+            TensorProto.FLOAT are converted: float_data is moved to int32_data as
+            bfloat16 values, and raw_data is rewritten in place.
+
+    Raises:
+        ValueError: input type is not TensorProto.
+
+    Returns:
+        TensorProto: the converted tensor.
+    """
+
+    if not isinstance(tensor, TensorProto):
+        raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
+
+    if tensor.data_type == TensorProto.FLOAT:
+        tensor.data_type = TensorProto.BFLOAT16
+        # convert float_data (float type) to bfloat16 and write to int32_data
+        if tensor.float_data:
+            bfloat16_data = np.asarray(tensor.float_data, dtype=np.float32).astype(ml_dtypes.bfloat16)
+            # _npfloat16_to_int also works for bfloat16 because float16 and bfloat16 are both 16-bit types.
+            int_list = _npfloat16_to_int(bfloat16_data)
+            tensor.int32_data[:] = int_list
+            tensor.float_data[:] = []
+        # convert raw_data (bytes type)
+        if tensor.raw_data:
+            # convert tensor.raw_data to float32
+            float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
+            # convert float32 to bfloat16
+            bfloat16_list = float32_list.astype(ml_dtypes.bfloat16)
+            # convert bfloat16 to bytes and write back to raw_data
+            tensor.raw_data = bfloat16_list.tobytes()
+    return tensor
+
 
 def make_value_info_from_tensor(tensor):
     shape = numpy_helper.to_array(tensor).shape
@@ -148,6 +186,10 @@ def make_value_info_from_tensor(tensor):
 # Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this.
 ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}
 
+class NodeValueType(Enum):
+    FP32 = 1
+    FP16 = 2
+    BF16 = 3
 
 class InitializerTracker:
     """Class for keeping track of initializer."""
@@ -156,13 +198,15 @@ def __init__(self, initializer: TensorProto):
         self.initializer = initializer
         self.fp32_nodes = []
         self.fp16_nodes = []
+        self.bf16_nodes = []
 
-    def add_node(self, node: NodeProto, is_node_blocked):
-        if is_node_blocked:
+    def add_node(self, node: NodeProto, node_value_type):
+        if node_value_type == NodeValueType.FP32:
             self.fp32_nodes.append(node)
-        else:
+        elif node_value_type == NodeValueType.FP16:
             self.fp16_nodes.append(node)
-
+        elif node_value_type == NodeValueType.BF16:
+            self.bf16_nodes.append(node)
 
 def convert_float_to_float16(
     model,
@@ -332,12 +376,17 @@ def convert_float_to_float16(
                     is_node_blocked = n.op_type in op_block_list or n.name in node_block_list
                     for i, input_name in enumerate(n.input):
                         if input_name in fp32_initializers:
-                            # For Resize/GroupNorm, only the first input can be float16
-                            use_fp32_weight = is_node_blocked or (
-                                i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
-                                and i not in force_fp16_inputs_dict.get(n.op_type, [])
-                            )
-                            fp32_initializers[input_name].add_node(n, use_fp32_weight)
+                            if is_node_blocked and use_bfloat16_as_blocked_nodes_dtype:
+                                fp32_initializers[input_name].add_node(n, NodeValueType.BF16)
+                            else:
+                                # For Resize/GroupNorm, only the first input can be float16
+                                if is_node_blocked or (
+                                    i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
+                                    and i not in force_fp16_inputs_dict.get(n.op_type, [])
+                                ):
+                                    fp32_initializers[input_name].add_node(n, NodeValueType.FP32)
+                                else:
+                                    fp32_initializers[input_name].add_node(n, NodeValueType.FP16)
 
                     if is_node_blocked:
                         node_list.append(n)
@@ -413,6 +462,10 @@ def convert_float_to_float16(
                 logger.info(
                     f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}"
                 )
+        if value.bf16_nodes:
+            value.initializer = convert_tensor_float_to_bfloat16(value.initializer)
+            value_info_list.append(make_value_info_from_tensor(value.initializer))
+
 
     # Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs.
     for node in mixed_float_type_node_list:
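
Usage sketch (not part of the diff): the new kernel is reached from the Python side by keeping SkipSimplifiedLayerNormalization on the op block list and requesting bfloat16 as the blocked-node dtype, so its fp32 initializers go through convert_tensor_float_to_bfloat16 above. The import path onnxruntime.transformers.float16, the model file names, and the printed round trip below are illustrative assumptions, not something this patch adds.

    import ml_dtypes
    import numpy as np
    import onnx

    # Assumed import path for the helper patched above.
    from onnxruntime.transformers.float16 import convert_float_to_float16

    model = onnx.load("model_fp32.onnx")  # illustrative input model
    converted = convert_float_to_float16(
        model,
        keep_io_types=True,
        op_block_list=["SkipSimplifiedLayerNormalization"],  # keep the op out of the fp16 pass ...
        use_bfloat16_as_blocked_nodes_dtype=True,  # ... and route its fp32 initializers to convert_tensor_float_to_bfloat16
    )
    onnx.save(converted, "model_bf16.onnx")

    # The raw_data rewrite in convert_tensor_float_to_bfloat16 is a plain float32 -> bfloat16
    # cast; this round trip shows the expected precision loss from the narrower significand.
    x = np.array([1.0, 3.14159265, 65504.0], dtype=np.float32)
    raw = x.astype(ml_dtypes.bfloat16).tobytes()
    print(np.frombuffer(raw, dtype=ml_dtypes.bfloat16).astype(np.float32))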