diff --git a/python/paddle/amp/amp_lists.py b/python/paddle/amp/amp_lists.py
index 3aefd298340a53..7d014b1bf14f92 100644
--- a/python/paddle/amp/amp_lists.py
+++ b/python/paddle/amp/amp_lists.py
@@ -12,22 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-FP16_WHITE_LIST = {
+# The set of ops that support fp16 and bf16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16 or bf16.
+WHITE_LIST = {
     'conv2d',
     'einsum',
     'matmul',
     'matmul_v2',
     'max_pool2d_with_index',
     'mul',
+    'fused_gemm_epilogue',
+}
+
+# The set of ops that support fp16 but not bf16.
+ONLY_FP16_WHITE_LIST = {
     'fake_quantize_dequantize_abs_max',
     'fake_quantize_dequantize_moving_average_abs_max',
-    'fused_gemm_epilogue',
     'fused_attention',
     'fused_feedforward',
 }
 
+FP16_WHITE_LIST = WHITE_LIST | ONLY_FP16_WHITE_LIST
+
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
 FP16_BLACK_LIST = {
@@ -90,8 +96,8 @@
     'scatter',
 }
 
-BF16_WHITE_LIST = {'conv2d', 'einsum', 'matmul_v2'}
-BF16_BLACK_LIST = set()
+BF16_WHITE_LIST = WHITE_LIST
+BF16_BLACK_LIST = FP16_BLACK_LIST
 
 
 # At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
diff --git a/python/paddle/static/amp/bf16/amp_lists.py b/python/paddle/static/amp/bf16/amp_lists.py
index 5ea5beb708b894..cd4d6bdb329b40 100644
--- a/python/paddle/static/amp/bf16/amp_lists.py
+++ b/python/paddle/static/amp/bf16/amp_lists.py
@@ -14,6 +14,7 @@
 
 import copy
 
+from paddle.amp.amp_lists import BF16_WHITE_LIST
 from paddle.fluid import core
 
 from ..fp16_lists import black_list as black_list_fp16
@@ -86,33 +87,10 @@ def _update_list(self):
 bf16_initializer_list = {'fill_constant', 'uniform_random'}
 
 # always bf16
-bf16_list = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
+bf16_list = BF16_WHITE_LIST
 
 # depends on the prev_op type
-gray_list = {
-    'elementwise_add',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_div',
-    'relu',
-    'layer_norm',
-    'slice',
-    'concat',
-    'uniform_random',
-    'reshape2',
-    'transpose2',
-    'pool2d',
-    'sigmoid',
-    'cast',
-    'scale',
-    'fill_constant',
-    'split',
-}
+gray_list = gray_list_fp16
 
 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
     'CPU', core.VarDesc.VarType.BF16
diff --git a/test/contrib/test_bf16_utils.py b/test/contrib/test_bf16_utils.py
index c44e5a4a97481c..75ce0045b39abf 100644
--- a/test/contrib/test_bf16_utils.py
+++ b/test/contrib/test_bf16_utils.py
@@ -46,10 +46,10 @@ def test_amp_lists_1(self):
 
     def test_amp_lists_2(self):
         # 2. w={'tanh'}, b=None
-        self.fp32_list.remove('tanh')
-        self.bf16_list.add('tanh')
+        self.fp32_list.remove('tan')
+        self.bf16_list.add('tan')
 
-        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})
+        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
 
     def test_amp_lists_3(self):
         # 3. w={'lstm'}, b=None
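
For reviewers, a minimal sketch of the set relationships this refactor establishes. It assumes the patch above is applied and Paddle is installed; all names are exactly those introduced or kept by the diff in `python/paddle/amp/amp_lists.py`, nothing else is assumed.

```python
# A minimal sketch, assuming this diff is applied to paddle.amp.amp_lists.
# It only checks the invariants the refactor is meant to establish.
from paddle.amp.amp_lists import (
    WHITE_LIST,
    ONLY_FP16_WHITE_LIST,
    FP16_WHITE_LIST,
    FP16_BLACK_LIST,
    BF16_WHITE_LIST,
    BF16_BLACK_LIST,
)

# The fp16 white list is the shared white list plus the fp16-only ops.
assert FP16_WHITE_LIST == WHITE_LIST | ONLY_FP16_WHITE_LIST

# bf16 now reuses the shared white list and the fp16 black list directly.
assert BF16_WHITE_LIST == WHITE_LIST
assert BF16_BLACK_LIST == FP16_BLACK_LIST

# Ops in ONLY_FP16_WHITE_LIST (e.g. 'fused_attention') stay out of the bf16 list.
assert 'fused_attention' in FP16_WHITE_LIST
assert 'fused_attention' not in BF16_WHITE_LIST
```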