Set bf16 black_list and white_list #55713

Merged
merged 5 commits on Aug 25, 2023
18 changes: 12 additions & 6 deletions python/paddle/amp/amp_lists.py
@@ -12,22 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-FP16_WHITE_LIST = {
+# The set of ops that support fp16 and bf16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16 or bf16.
+WHITE_LIST = {
     'conv2d',
     'einsum',
     'matmul',
     'matmul_v2',
     'max_pool2d_with_index',
     'mul',
-    'fused_gemm_epilogue',
 }
 
+# The set of ops that support fp16 but not bf16.
+ONLY_FP16_WHITE_LIST = {
+    'fake_quantize_dequantize_abs_max',
+    'fake_quantize_dequantize_moving_average_abs_max',
+    'fused_gemm_epilogue',
+    'fused_attention',
+    'fused_feedforward',
+}
+
+FP16_WHITE_LIST = WHITE_LIST | ONLY_FP16_WHITE_LIST
+
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
 FP16_BLACK_LIST = {
@@ -90,8 +96,8 @@
     'scatter',
 }
 
-BF16_WHITE_LIST = {'conv2d', 'einsum', 'matmul_v2'}
-BF16_BLACK_LIST = set()
+BF16_WHITE_LIST = WHITE_LIST
+BF16_BLACK_LIST = FP16_BLACK_LIST

Contributor:
einsum should not be removed here; it was added deliberately. Consider unifying the bf16 and fp16 white lists by taking their union; they were left inconsistent for historical reasons. Note, however, that some fused ops in the white lists may support only low precision, or only one particular low-precision type.

Contributor Author:
OK.



 # At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
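For reference, a minimal sketch of the set relationships this diff establishes, written as plain assertions. It assumes a Paddle build that already contains this change; every imported name is taken directly from the hunks above.

# Sanity checks for the new AMP list layout (names come from the diff above;
# assumes a Paddle build that already includes this change).
from paddle.amp.amp_lists import (
    BF16_BLACK_LIST,
    BF16_WHITE_LIST,
    FP16_BLACK_LIST,
    FP16_WHITE_LIST,
    ONLY_FP16_WHITE_LIST,
    WHITE_LIST,
)

# The fp16 white list is now the shared list plus the fp16-only ops.
assert FP16_WHITE_LIST == WHITE_LIST | ONLY_FP16_WHITE_LIST

# bf16 reuses the shared lists, so ops such as 'einsum' and
# 'max_pool2d_with_index' are now on the bf16 white list as well.
assert BF16_WHITE_LIST == WHITE_LIST
assert BF16_BLACK_LIST == FP16_BLACK_LIST

# Ops that only support fp16 must not appear on the bf16 white list.
assert ONLY_FP16_WHITE_LIST.isdisjoint(BF16_WHITE_LIST)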
28 changes: 3 additions & 25 deletions python/paddle/static/amp/bf16/amp_lists.py
@@ -14,6 +14,7 @@

 import copy
 
+from paddle.amp.amp_lists import BF16_WHITE_LIST
 from paddle.fluid import core
 
 from ..fp16_lists import black_list as black_list_fp16
@@ -86,33 +87,10 @@ def _update_list(self):
 bf16_initializer_list = {'fill_constant', 'uniform_random'}
 
 # always bf16
-bf16_list = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
+bf16_list = BF16_WHITE_LIST
 
 # depends on the prev_op type
-gray_list = {
-    'elementwise_add',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_div',
-    'relu',
-    'layer_norm',
-    'slice',
-    'concat',
-    'uniform_random',
-    'reshape2',
-    'transpose2',
-    'pool2d',
-    'sigmoid',
-    'cast',
-    'scale',
-    'fill_constant',
-    'split',
-}
+gray_list = gray_list_fp16
 
 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
     'CPU', core.VarDesc.VarType.BF16
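To make the effect of this hunk concrete, here is a hedged sketch of how the static bf16 module now derives its op list. The first two statements mirror the diff; the final subtraction is hypothetical, since the code that consumes _sys_unsupported_bf16_list lies outside this hunk.

# Illustrative sketch (not the file's verbatim contents): the static bf16
# list is now sourced from the shared BF16_WHITE_LIST, and ops without
# bf16 CPU kernels are queried so they can be excluded.
from paddle.amp.amp_lists import BF16_WHITE_LIST
from paddle.fluid import core

bf16_list = BF16_WHITE_LIST  # single source of truth shared with dynamic AMP

_, _, _sys_unsupported_bf16_list = core.op_supported_infos(
    'CPU', core.VarDesc.VarType.BF16
)

# Hypothetical final step: keep only ops with real bf16 CPU kernels.
runnable_bf16_list = bf16_list - _sys_unsupported_bf16_list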
6 changes: 3 additions & 3 deletions test/contrib/test_bf16_utils.py
@@ -46,10 +46,10 @@ def test_amp_lists_1(self):

     def test_amp_lists_2(self):
         # 2. w={'tanh'}, b=None
-        self.fp32_list.remove('tanh')
-        self.bf16_list.add('tanh')
+        self.fp32_list.remove('tan')
+        self.bf16_list.add('tan')
 
-        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})
+        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
 
     def test_amp_lists_3(self):
         # 3. w={'lstm'}, b=None
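Based on the updated test, a short usage sketch of the custom white-list path follows. The bf16_list and fp32_list attribute names mirror the test's setup code and are an assumption to the extent they are not shown in this diff.

# Usage sketch inferred from test_amp_lists_2 above: passing a custom
# white list moves 'tan' out of the default fp32 list and into bf16_list.
# (The bf16_list/fp32_list attribute names follow the test suite's setUp
# code and are an assumption here.)
import paddle.static.amp as amp

amp_lists = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
assert 'tan' in amp_lists.bf16_list
assert 'tan' not in amp_lists.fp32_list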