Set bf16 black_list and white_list (PaddlePaddle#55713)
AnnaTrainingG authored and BeingGod committed Sep 9, 2023
1 parent c368c8b commit 99be01a
Showing 3 changed files with 18 additions and 34 deletions.
18 changes: 12 additions & 6 deletions python/paddle/amp/amp_lists.py
@@ -12,22 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-FP16_WHITE_LIST = {
+# The set of ops that support fp16 and bf16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16 or bf16.
+WHITE_LIST = {
     'conv2d',
     'einsum',
     'matmul',
     'matmul_v2',
     'max_pool2d_with_index',
     'mul',
-    'fused_gemm_epilogue',
 }
 
+# The set of ops that support fp16 but not bf16.
+ONLY_FP16_WHITE_LIST = {
+    'fake_quantize_dequantize_abs_max',
+    'fake_quantize_dequantize_moving_average_abs_max',
+    'fused_gemm_epilogue',
+    'fused_attention',
+    'fused_feedforward',
+}
+
+FP16_WHITE_LIST = WHITE_LIST | ONLY_FP16_WHITE_LIST
+
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
 FP16_BLACK_LIST = {
@@ -90,8 +96,8 @@
     'scatter',
 }
 
-BF16_WHITE_LIST = {'conv2d', 'einsum', 'matmul_v2'}
-BF16_BLACK_LIST = set()
+BF16_WHITE_LIST = WHITE_LIST
+BF16_BLACK_LIST = FP16_BLACK_LIST
 
 
 # At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
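
The net effect: FP16_WHITE_LIST becomes the union of the shared WHITE_LIST and the fp16-only ops, while both bf16 lists now alias the shared fp16 sets. A minimal sketch of the resulting semantics, assuming a Paddle build that includes this commit:

# Sketch only: the imported names are taken from the diff above and
# require a Paddle build that includes this commit.
from paddle.amp.amp_lists import (
    BF16_WHITE_LIST,
    FP16_WHITE_LIST,
    ONLY_FP16_WHITE_LIST,
    WHITE_LIST,
)

# The shared core ops are white-listed for both precisions.
assert WHITE_LIST <= FP16_WHITE_LIST
assert BF16_WHITE_LIST == WHITE_LIST

# Fused ops without bf16 support stay fp16-only.
assert 'fused_gemm_epilogue' in FP16_WHITE_LIST
assert 'fused_gemm_epilogue' not in BF16_WHITE_LIST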
28 changes: 3 additions & 25 deletions python/paddle/static/amp/bf16/amp_lists.py
@@ -14,6 +14,7 @@
 
 import copy
 
+from paddle.amp.amp_lists import BF16_WHITE_LIST
 from paddle.fluid import core
 
 from ..fp16_lists import black_list as black_list_fp16
@@ -86,33 +87,10 @@ def _update_list(self):
 bf16_initializer_list = {'fill_constant', 'uniform_random'}
 
 # always bf16
-bf16_list = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
+bf16_list = BF16_WHITE_LIST
 
 # depends on the prev_op type
-gray_list = {
-    'elementwise_add',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_div',
-    'relu',
-    'layer_norm',
-    'slice',
-    'concat',
-    'uniform_random',
-    'reshape2',
-    'transpose2',
-    'pool2d',
-    'sigmoid',
-    'cast',
-    'scale',
-    'fill_constant',
-    'split',
-}
+gray_list = gray_list_fp16
 
 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
     'CPU', core.VarDesc.VarType.BF16
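
With this change the static-graph bf16 defaults are derived from the shared lists instead of being maintained by hand. A hedged sketch of how that surfaces through the list object; the attribute names follow this module, but the exact contents depend on the installed build and on which bf16 kernels the CPU reports:

# Sketch only: inspects the default static bf16 lists after this change.
import paddle.static.amp as amp

lists = amp.bf16.AutoMixedPrecisionListsBF16()
print(sorted(lists.bf16_list))  # now seeded from the shared BF16_WHITE_LIST
print(sorted(lists.gray_list))  # now reused from the fp16 gray list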
6 changes: 3 additions & 3 deletions test/contrib/test_bf16_utils.py
@@ -46,10 +46,10 @@ def test_amp_lists_1(self):
 
     def test_amp_lists_2(self):
         # 2. w={'tanh'}, b=None
-        self.fp32_list.remove('tanh')
-        self.bf16_list.add('tanh')
+        self.fp32_list.remove('tan')
+        self.bf16_list.add('tan')
 
-        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})
+        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
 
     def test_amp_lists_3(self):
         # 3. w={'lstm'}, b=None
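
The test swaps 'tanh' for 'tan', presumably because 'tanh' now falls into the reused fp16 gray list and so no longer starts out in the default fp32 list, while 'tan' still does. A sketch of the scenario the updated test exercises; the membership checks are assumptions based on this diff, not guaranteed API behavior:

# Sketch only: mirrors the updated test_amp_lists_2 scenario.
import paddle.static.amp as amp

amp_lists = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
# A custom white list is expected to move 'tan' out of the default
# fp32 list and into the bf16 list.
assert 'tan' in amp_lists.bf16_list
assert 'tan' not in amp_lists.fp32_list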
