-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Test only] BFloat16 test for SkipSimplifiedLayerNormalization #22941
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -17,6 +17,8 @@ | |||||||||||||||||||||||||
import os | ||||||||||||||||||||||||||
import tempfile | ||||||||||||||||||||||||||
from typing import Dict | ||||||||||||||||||||||||||
from enum import Enum | ||||||||||||||||||||||||||
import ml_dtypes | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||
Comment on lines
18
to
23
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
import onnx | ||||||||||||||||||||||||||
|
@@ -36,7 +38,6 @@ def _npfloat16_to_int(np_list): | |||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0): | ||||||||||||||||||||||||||
Comment on lines
40
to
41
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
Convert float32 numpy array to float16 without changing sign or finiteness. | ||||||||||||||||||||||||||
|
@@ -107,6 +108,43 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit | |||||||||||||||||||||||||
tensor.raw_data = float16_list.tobytes() | ||||||||||||||||||||||||||
return tensor | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def convert_tensor_float_to_bfloat16(tensor): | ||||||||||||||||||||||||||
Comment on lines
110
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
"""Convert tensor float to bfloat16. | ||||||||||||||||||||||||||
Args: | ||||||||||||||||||||||||||
tensor (TensorProto): the tensor to convert. | ||||||||||||||||||||||||||
min_positive_val (float, optional): minimal positive value. Defaults to 1e-7. | ||||||||||||||||||||||||||
max_finite_val (float, optional): maximal finite value. Defaults to 1e4. | ||||||||||||||||||||||||||
Raises: | ||||||||||||||||||||||||||
ValueError: input type is not TensorProto. | ||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||
TensorProto: the converted tensor. | ||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if not isinstance(tensor, TensorProto): | ||||||||||||||||||||||||||
raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}") | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if tensor.data_type == TensorProto.FLOAT: | ||||||||||||||||||||||||||
tensor.data_type = TensorProto.BFLOAT16 | ||||||||||||||||||||||||||
# convert float_data (float type) to bfloat16 and write to int32_data | ||||||||||||||||||||||||||
if tensor.float_data: | ||||||||||||||||||||||||||
bfloat16_data = tensor.float_data.astype(ml_dtypes.bfloat16) | ||||||||||||||||||||||||||
# we can use _npfloat16_to_int here because float16 and bfloat16 are both 16-bits. | ||||||||||||||||||||||||||
int_list = _npfloat16_to_int(bfloat16_data) | ||||||||||||||||||||||||||
tensor.int32_data[:] = int_list | ||||||||||||||||||||||||||
tensor.float_data[:] = [] | ||||||||||||||||||||||||||
# convert raw_data (bytes type) | ||||||||||||||||||||||||||
if tensor.raw_data: | ||||||||||||||||||||||||||
# convert n.raw_data to float | ||||||||||||||||||||||||||
float32_list = np.frombuffer(tensor.raw_data, dtype="float32") | ||||||||||||||||||||||||||
# convert float to bfloat16 | ||||||||||||||||||||||||||
bfloat16_list = float32_list.astype(ml_dtypes.bfloat16) | ||||||||||||||||||||||||||
# convert bfloat16 to bytes and write back to raw_data | ||||||||||||||||||||||||||
tensor.raw_data = bfloat16_list.tobytes() | ||||||||||||||||||||||||||
return tensor | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def make_value_info_from_tensor(tensor): | ||||||||||||||||||||||||||
shape = numpy_helper.to_array(tensor).shape | ||||||||||||||||||||||||||
|
@@ -148,6 +186,10 @@ def make_value_info_from_tensor(tensor): | |||||||||||||||||||||||||
# Note that DirectML allows float16 gamma and beta in GroupNorm. Use force_fp16_inputs parameter could overwrite this. | ||||||||||||||||||||||||||
ALWAYS_FLOAT_INPUTS = {"Resize": [2], "GroupNorm": [1, 2], "SkipGroupNorm": [1, 2]} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
class NodeValueType(Enum): | ||||||||||||||||||||||||||
FP32 = 1 | ||||||||||||||||||||||||||
Comment on lines
+189
to
+190
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
FP16 = 2 | ||||||||||||||||||||||||||
BF16 = 3 | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
class InitializerTracker: | ||||||||||||||||||||||||||
"""Class for keeping track of initializer.""" | ||||||||||||||||||||||||||
Comment on lines
194
to
195
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
|
@@ -156,13 +198,15 @@ def __init__(self, initializer: TensorProto): | |||||||||||||||||||||||||
self.initializer = initializer | ||||||||||||||||||||||||||
self.fp32_nodes = [] | ||||||||||||||||||||||||||
self.fp16_nodes = [] | ||||||||||||||||||||||||||
self.bf16_nodes = [] | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def add_node(self, node: NodeProto, is_node_blocked): | ||||||||||||||||||||||||||
if is_node_blocked: | ||||||||||||||||||||||||||
def add_node(self, node: NodeProto, node_value_type): | ||||||||||||||||||||||||||
if node_value_type == NodeValueType.FP32: | ||||||||||||||||||||||||||
self.fp32_nodes.append(node) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
elif node_value_type == NodeValueType.FP16: | ||||||||||||||||||||||||||
self.fp16_nodes.append(node) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
elif node_value_type == NodeValueType.BF16: | ||||||||||||||||||||||||||
self.bf16_nodes.append(node) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
def convert_float_to_float16( | ||||||||||||||||||||||||||
model, | ||||||||||||||||||||||||||
Comment on lines
211
to
212
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
|
@@ -332,12 +376,17 @@ def convert_float_to_float16( | |||||||||||||||||||||||||
is_node_blocked = n.op_type in op_block_list or n.name in node_block_list | ||||||||||||||||||||||||||
for i, input_name in enumerate(n.input): | ||||||||||||||||||||||||||
if input_name in fp32_initializers: | ||||||||||||||||||||||||||
# For Resize/GroupNorm, only the first input can be float16 | ||||||||||||||||||||||||||
use_fp32_weight = is_node_blocked or ( | ||||||||||||||||||||||||||
i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) | ||||||||||||||||||||||||||
and i not in force_fp16_inputs_dict.get(n.op_type, []) | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
fp32_initializers[input_name].add_node(n, use_fp32_weight) | ||||||||||||||||||||||||||
if is_node_blocked and use_bfloat16_as_blocked_nodes_dtype: | ||||||||||||||||||||||||||
fp32_initializers[input_name].add_node(n, NodeValueType.BF16) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
# For Resize/GroupNorm, only the first input can be float16 | ||||||||||||||||||||||||||
if is_node_blocked or ( | ||||||||||||||||||||||||||
i in ALWAYS_FLOAT_INPUTS.get(n.op_type, []) | ||||||||||||||||||||||||||
and i not in force_fp16_inputs_dict.get(n.op_type, []) | ||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||
fp32_initializers[input_name].add_node(n, NodeValueType.FP32) | ||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||
fp32_initializers[input_name].add_node(n, NodeValueType.FP16) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
if is_node_blocked: | ||||||||||||||||||||||||||
node_list.append(n) | ||||||||||||||||||||||||||
|
@@ -413,6 +462,10 @@ def convert_float_to_float16( | |||||||||||||||||||||||||
logger.info( | ||||||||||||||||||||||||||
f"initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{value.fp16_nodes}" | ||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||
if value.bf16_nodes: | ||||||||||||||||||||||||||
value.initializer = convert_tensor_float_to_bfloat16(value.initializer) | ||||||||||||||||||||||||||
value_info_list.append(make_value_info_from_tensor(value.initializer)) | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
# Some operators have data type fixed as float for some input. Add a float16 to float cast for those inputs. | ||||||||||||||||||||||||||
for node in mixed_float_type_node_list: | ||||||||||||||||||||||||||
Comment on lines
469
to
471
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.