[Test only] BFloat16 test for SkipSimplifiedLayerNormalization #22941

Open
wants to merge 1 commit into base: main
24 changes: 24 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh
@@ -25,6 +25,7 @@
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

[cpplint] reported by reviewdog (GitHub Actions / Optional Lint C++): onnxruntime/contrib_ops/cuda/bert/layer_norm.cuh:28: Found C system header after other header. Should be: layer_norm.h, c system, c++ system, other. [build/include_order] [4]
#include <cublas_v2.h>
#include <cub/cub.cuh>

@@ -60,6 +61,15 @@
#endif
}

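// BFloat16 counterparts of the half-precision helpers above (intrinsics from <cuda_bf16.h>).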
template <>
__device__ inline nv_bfloat16 Rsqrt(const nv_bfloat16& x) {
return hrsqrt(x);
}

__device__ inline nv_bfloat162 AddHalf2(const nv_bfloat162 a, const nv_bfloat162 b) {
return __hadd2(a, b);
}

struct KeyValuePairSum {
__device__ inline cub::KeyValuePair<float, float> operator()(const cub::KeyValuePair<float, float>& a,
const cub::KeyValuePair<float, float>& b) {
@@ -78,6 +88,20 @@
const cub::KeyValuePair<half2, half2>& b) {
return cub::KeyValuePair<half2, half2>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}


__device__ inline cub::KeyValuePair<nv_bfloat16, nv_bfloat16> operator()(const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& a,
const cub::KeyValuePair<nv_bfloat16, nv_bfloat16>& b) {
const nv_bfloat162 a2 = __halves2bfloat162(a.key, a.value);
const nv_bfloat162 b2 = __halves2bfloat162(b.key, b.value);
const nv_bfloat162 res = AddHalf2(a2, b2);
return cub::KeyValuePair<nv_bfloat16, nv_bfloat16>(__low2bfloat16(res), __high2bfloat16(res));
}

__device__ inline cub::KeyValuePair<nv_bfloat162, nv_bfloat162> operator()(const cub::KeyValuePair<nv_bfloat162, nv_bfloat162>& a,
const cub::KeyValuePair<nv_bfloat162, nv_bfloat162>& b) {
return cub::KeyValuePair<nv_bfloat162, nv_bfloat162>(AddHalf2(a.key, b.key), AddHalf2(a.value, b.value));
}
};

template <typename T, int TPB>
45 changes: 32 additions & 13 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -34,6 +34,7 @@ namespace cuda {

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)

using namespace ONNX_NAMESPACE;

@@ -105,19 +106,37 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
        (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,  // bias to add
        sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
  } else {
    if (std::is_same<T, BFloat16>::value) {
      LaunchSkipLayerNormKernel<nv_bfloat16, Simplified>(
          Stream(ctx),
          reinterpret_cast<nv_bfloat16*>(output->MutableData<T>()),
          sum_output != nullptr ? reinterpret_cast<nv_bfloat16*>(sum_output->MutableData<T>()) : nullptr,
          reinterpret_cast<const nv_bfloat16*>(input->Data<T>()),
          reinterpret_cast<const nv_bfloat16*>(skip->Data<T>()),
          (bias != nullptr) ? reinterpret_cast<const nv_bfloat16*>(bias->Data<T>()) : nullptr,
          reinterpret_cast<const nv_bfloat16*>(gamma->Data<T>()),
          (beta != nullptr) ? reinterpret_cast<const nv_bfloat16*>(beta->Data<T>()) : nullptr,
          epsilon_,
          hidden_size,
          row_count,
          skip_size);
    }
    else
    {
      LaunchSkipLayerNormKernel<CudaT, Simplified>(
          Stream(ctx),
          reinterpret_cast<CudaT*>(output->MutableData<T>()),
          sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr,
          reinterpret_cast<const CudaT*>(input->Data<T>()),
          reinterpret_cast<const CudaT*>(skip->Data<T>()),
          (bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr,
          reinterpret_cast<const CudaT*>(gamma->Data<T>()),
          (beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr,
          epsilon_,
          hidden_size,
          row_count,
          skip_size);
    }
  }

Reviewer comment (Contributor) on lines +122 to +126, suggested change: write the inner branch as `} else {` on a single line instead of putting `else` and `{` on separate lines.

CUDA_RETURN_IF_ERROR(cudaGetLastError());
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -30,6 +30,7 @@
#include "contrib_ops/cuda/bert/layer_norm.cuh"
#include "contrib_ops/cuda/bert/skip_layer_norm_impl.h"
#include <cuda_fp16.h>
#include <cuda_bf16.h>

[cpplint] reported by reviewdog (GitHub Actions / Optional Lint C++): onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu:33: Found C system header after other header. Should be: skip_layer_norm_impl.h, c system, c++ system, other. [build/include_order] [4]

namespace onnxruntime {
namespace contrib {
@@ -49,6 +50,11 @@
return __float2half_rn(x);
}

template <>
nv_bfloat16 maybe2half(float x) {
return __float2bfloat16_rn(x);
}

// Using only power of 2 numbers will lead to waste of compute for some sizes such as 768, which is a very common case
// in BERT. Ideally we can step by warp_size * num_unroll, but listing too many steps will cause long compile time.
constexpr int kSizes[] = {128, 320, 384, 640, 768, 1024, 1280, 2048, 4096, 5120, 8192};
@@ -263,6 +269,8 @@
SKIPLAYERNORM_IMPL(float, false);
SKIPLAYERNORM_IMPL(half, true);
SKIPLAYERNORM_IMPL(half, false);
SKIPLAYERNORM_IMPL(nv_bfloat16, true);
SKIPLAYERNORM_IMPL(nv_bfloat16, false);

} // namespace cuda
} // namespace contrib
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -123,6 +123,7 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization);
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu);
class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu);
@@ -327,6 +328,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, float, ThresholdedRelu)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, ThresholdedRelu)>,
BuildKernelCreateInfo<CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, ThresholdedRelu)>,
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -1611,7 +1611,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"with shape (batch_size, sequence_length, hidden_size) or (token_count, hidden_size).",
"T",
OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or half tensors.")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float, half, bfloat16 tensors.")
.TypeConstraint("U", {"tensor(float)"}, "Constrain mean and inv_std_var to float tensors.")
.TypeAndShapeInferenceFunction(SkipLayerNormalizationShapeInference));

75 changes: 64 additions & 11 deletions onnxruntime/python/tools/transformers/float16.py
@@ -17,6 +17,8 @@
import os
import tempfile
from typing import Dict
from enum import Enum
import ml_dtypes

import numpy as np
Reviewer comment (Contributor) on lines 18 to 23, suggested change: keep the standard-library imports sorted (`from enum import Enum` before `from typing import Dict`) and group the third-party imports (`import ml_dtypes`, `import numpy as np`) after them, separated by a blank line.

import onnx
@@ -36,7 +38,6 @@ def _npfloat16_to_int(np_list):
"""
return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list]


def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0):
"""
Convert float32 numpy array to float16 without changing sign or finiteness.
@@ -107,6 +108,43 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit
tensor.raw_data = float16_list.tobytes()
return tensor

def convert_tensor_float_to_bfloat16(tensor):
"""Convert tensor float to bfloat16.
Args:
tensor (TensorProto): the tensor to convert.
min_positive_val (float, optional): minimal positive value. Defaults to 1e-7.
max_finite_val (float, optional): maximal finite value. Defaults to 1e4.
Raises:
ValueError: input type is not TensorProto.
Returns:
TensorProto: the converted tensor.
"""

    if not isinstance(tensor, TensorProto):
        raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")

    if tensor.data_type == TensorProto.FLOAT:
        tensor.data_type = TensorProto.BFLOAT16
        # convert float_data (float type) to bfloat16 and write to int32_data
        if tensor.float_data:
            bfloat16_data = np.array(tensor.float_data, dtype=np.float32).astype(ml_dtypes.bfloat16)
            # we can use _npfloat16_to_int here because float16 and bfloat16 are both 16 bits.
            int_list = _npfloat16_to_int(bfloat16_data)
            tensor.int32_data[:] = int_list
            tensor.float_data[:] = []
        # convert raw_data (bytes type)
        if tensor.raw_data:
            # convert tensor.raw_data to a float32 array
            float32_list = np.frombuffer(tensor.raw_data, dtype="float32")
            # convert float32 to bfloat16
            bfloat16_list = float32_list.astype(ml_dtypes.bfloat16)
            # convert bfloat16 to bytes and write back to raw_data
            tensor.raw_data = bfloat16_list.tobytes()
    return tensor


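For reference, a minimal sketch of how the new converter behaves on a raw_data-backed initializer. The import assumes the script runs next to this float16.py; the tensor name and values are illustrative, not taken from the PR:

import numpy as np
from onnx import TensorProto, numpy_helper

from float16 import convert_tensor_float_to_bfloat16  # this module

# Hypothetical float32 initializer; numpy_helper.from_array stores the values in raw_data.
weight = numpy_helper.from_array(np.array([0.5, -1.25, 3.0], dtype=np.float32), name="w")

converted = convert_tensor_float_to_bfloat16(weight)
assert converted.data_type == TensorProto.BFLOAT16
# raw_data now holds 2 bytes per element (bfloat16) instead of 4 (float32).
assert len(converted.raw_data) == 2 * 3
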
def make_value_info_from_tensor(tensor):
shape = numpy_helper.to_array(tensor).shape
@@ -148,6 +186,10 @@ def make_value_info_from_tensor(tensor):
# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this.
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]}

class NodeValueType(Enum):
FP32 = 1
FP16 = 2
BF16 = 3

class InitializerTracker:
"""Class for keeping track of initializer."""
@@ -156,13 +198,15 @@ def __init__(self, initializer: TensorProto):
self.initializer = initializer
self.fp32_nodes = []
self.fp16_nodes = []
self.bf16_nodes = []

    def add_node(self, node: NodeProto, node_value_type):
        if node_value_type == NodeValueType.FP32:
            self.fp32_nodes.append(node)
        elif node_value_type == NodeValueType.FP16:
            self.fp16_nodes.append(node)
        elif node_value_type == NodeValueType.BF16:
            self.bf16_nodes.append(node)

def convert_float_to_float16(
model,
@@ -332,12 +376,17 @@ def convert_float_to_float16(
is_node_blocked = n.op_type in op_block_list or n.name in node_block_list
for i, input_name in enumerate(n.input):
    if input_name in fp32_initializers:
        if is_node_blocked and use_bfloat16_as_blocked_nodes_dtype:
            fp32_initializers[input_name].add_node(n, NodeValueType.BF16)
        else:
            # For Resize/GroupNorm, only the first input can be float16
            if is_node_blocked or (
                i in ALWAYS_FLOAT_INPUTS.get(n.op_type, [])
                and i not in force_fp16_inputs_dict.get(n.op_type, [])
            ):
                fp32_initializers[input_name].add_node(n, NodeValueType.FP32)
            else:
                fp32_initializers[input_name].add_node(n, NodeValueType.FP16)

if is_node_blocked:
    node_list.append(n)
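
A rough sketch of how the BF16 path above could be exercised end to end. The model path and block list are placeholders, the import assumes the script runs next to this float16.py, and it assumes use_bfloat16_as_blocked_nodes_dtype is exposed as a keyword argument of convert_float_to_float16, as its use above suggests:

import onnx

from float16 import convert_float_to_float16  # this module

model = onnx.load("model.onnx")  # placeholder path

# Blocked ops normally keep float32 initializers; with the flag set, their
# float initializers go through the NodeValueType.BF16 branch above instead.
converted_model = convert_float_to_float16(
    model,
    op_block_list=["SkipSimplifiedLayerNormalization"],
    use_bfloat16_as_blocked_nodes_dtype=True,
)
onnx.save(converted_model, "model_fp16_bf16_blocked.onnx")
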
@@ -413,6 +462,10 @@ def convert_float_to_float16(
logger.info(
f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}"
)
if value.bf16_nodes:
value.initializer = convert_tensor_float_to_bfloat16(value.initializer)
value_info_list.append(make_value_info_from_tensor(value.initializer))


# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs.
for node in mixed_float_type_node_list: