
Commit

update
AnnaTrainingG committed May 12, 2023
1 parent 72cb09e commit 2b78304
Showing 3 changed files with 40 additions and 23 deletions.
41 changes: 23 additions & 18 deletions python/paddle/static/amp/decorator.py
@@ -61,12 +61,11 @@ class OptimizerWithMixedPrecision:
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
level(str): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represent mixed precision, the input data type
of each operator will be casted by white_list and black_list;
O2 represent Pure fp16 or bf16, all operators parameters and input
data will be casted to fp16 or bf16, except operators in black_list,
don't support fp16 or bf16 kernel and batch_norm.
level(str): Auto mixed precision level. Accepted values are "O1", "O2" and "OD".
At the O1 level, operators in the white list use float16/bfloat16 inputs for their
calculations and operators in the black list use float32 inputs. At the O2 level,
the model's parameters are cast to float16/bfloat16 by `decorate`; operators whose
inputs are all float16/bfloat16 run in float16/bfloat16, while operators with any
float32 input run in float32. At the OD level, only operators in the default white
list compute in float16/bfloat16.
dtype(str): The low precision data type, either 'float16' or 'bfloat16'.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
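To make the level semantics above concrete, the following is a minimal, self-contained sketch (plain Python, not Paddle's actual casting pass) of how an operator's compute dtype could be chosen from the white/black lists under each level. The helper name, the gray-list fallback, and the hard-coded OD white list are illustrative assumptions.

# Illustrative sketch only -- it mirrors the level semantics documented above,
# not Paddle's implementation. O1/OD pick a per-operator input dtype from the
# lists; O2 runs everything except black-list ops in the low precision type.
OD_DEFAULT_WHITE_LIST = {"conv2d", "matmul_v2"}  # per this commit

def choose_compute_dtype(op_type, input_dtypes, level,
                         white_list, black_list, low_dtype="float16"):
    """Return the dtype an operator would compute in under a given AMP level."""
    if level == "O1":
        if op_type in white_list:
            return low_dtype                 # white-list ops run in fp16/bf16
        if op_type in black_list:
            return "float32"                 # black-list ops keep fp32 inputs
        # gray-list ops follow their inputs: any fp32 input keeps fp32
        return "float32" if "float32" in input_dtypes else low_dtype
    if level == "OD":
        # only the default white list runs in low precision
        return low_dtype if op_type in OD_DEFAULT_WHITE_LIST else "float32"
    if level == "O2":
        # black-list ops (unsupported kernels, batch_norm-like ops) stay fp32
        return "float32" if op_type in black_list else low_dtype
    return "float32"                         # O0: pure fp32 training

# e.g. choose_compute_dtype("relu", {"float16"}, "OD", set(), set()) -> "float32"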
@@ -123,6 +122,7 @@ def __init__(
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = level == "O2"
self._amp_level = level
self._use_fp16_guard = use_amp_guard
self._to_fp16_var_names = None
if self._use_dynamic_loss_scaling:
@@ -258,7 +258,7 @@ def backward(
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
level=self._amp_level,
use_promote=self.use_promote,
)

@@ -397,7 +397,7 @@ def run_example_code():
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
level=self._amp_level,
use_promote=self.use_promote,
)

@@ -800,12 +800,11 @@ def decorate(
amp_lists(CustomOpLists, optional): A CustomOpLists object. The default
white_list and black_list will be used for AMP training when it is
not set. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represent mixed precision, the input data type of
each operator will be casted by white_list and black_list;
O2 represent pure FP16 / BF16 training, all operators parameters
and input data will be casted to FP16 / BF16, except operators in
black_list, don't support FP16 / BF16 kernel and batch_norm. Default is O1.
level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2"
and "OD". At the O1 level, operators in the white list use float16/bfloat16 inputs
for their calculations and operators in the black list use float32 inputs. At the
O2 level, the model's parameters are cast to float16/bfloat16 by `decorate`;
operators whose inputs are all float16/bfloat16 run in float16/bfloat16, while
operators with any float32 input run in float32. At the OD level, only operators
in the default white list compute in float16/bfloat16, and all other operators
compute in float32. Default is O1.
dtype(str, optional): The low precision data type, either 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optional): For level='O2', whether to use multi-precision
during weight updating. If master_weight is None, in O2 level optimizer
@@ -874,15 +873,21 @@ def forward(self, x):
"""
# check amp_level: O0, OD, O1, O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
if not (level in ['O0', 'OD', 'O1', 'O2']):
raise ValueError("level should be O0, OD, O1 or O2.")

amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

if level == 'OD':
warnings.warn("Amp level is OD, the amp list will not be used!")
if amp_lists is not None:
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list

if use_dynamic_loss_scaling is None:
use_dynamic_loss_scaling = dtype == "float16"

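For reference, here is a minimal usage sketch of the new OD level through the static-graph `decorate` API. The tiny network, the optimizer, and the reliance on default arguments are illustrative; actually running float16 kernels assumes a device that supports them.

import paddle
from paddle.static import amp

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 3, 32, 32], dtype="float32")
    conv = paddle.static.nn.conv2d(input=x, num_filters=8, filter_size=3)
    loss = paddle.mean(conv)

    optimizer = paddle.optimizer.SGD(learning_rate=1e-3)
    # level='OD': only the default white list (conv2d, matmul_v2) computes in
    # float16; every other operator keeps float32 inputs.
    optimizer = amp.decorate(optimizer, level='OD', dtype='float16')
    optimizer.minimize(loss)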
12 changes: 8 additions & 4 deletions python/paddle/static/amp/fp16_lists.py
@@ -116,13 +116,13 @@ def _get_sys_unsupported_list(dtype):
}
sys_unsupported_list -= supported_fp16_list

return device, sys_unsupported_list
return device, sys_unsupported_list, sys_unsupported_list


def _get_unsupported_list(dtype):
# The set of ops that don't support fp16 calculation
_, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
return _sys_unsupported_list
_, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype)
return _sys_unsupported_list, _sys_all_list


# The three sets listed below are changed dynamiclly. They don't contain all
@@ -201,7 +201,11 @@ def __init__(
self.white_list = copy.copy(_get_white_list(self.amp_dtype))
self.black_list = copy.copy(_get_black_list())
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
unsupported_list, sys_all_list = _get_unsupported_list(
self.amp_dtype
)
self.unsupported_list = copy.copy(unsupported_list)
self.all_list = copy.copy(sys_all_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()

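The new `all_list` attribute exists so the OD path can derive its black list by set difference. A toy sketch with plain Python sets; the op universe below is a made-up stand-in for `AutoMixedPrecisionLists.all_list`.

# Toy illustration of the OD list arithmetic used in this commit; the op
# universe is hypothetical, standing in for AutoMixedPrecisionLists.all_list.
all_list = {"conv2d", "matmul_v2", "relu", "softmax", "batch_norm", "reduce_mean"}

white_list = {"conv2d", "matmul_v2"}        # the default OD white list
black_list = all_list - white_list          # everything else stays float32

assert white_list.isdisjoint(black_list)
assert white_list | black_list == all_list  # the two lists partition the ops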
10 changes: 9 additions & 1 deletion python/paddle/static/amp/fp16_utils.py
@@ -426,7 +426,7 @@ def set_var_dst_dtype(

def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
keep_fp32_var_names = set()
if level == "O1":
if level == "O1" or level == "OD":
return keep_fp32_var_names
all_parameters = []
for block in program.blocks:
@@ -611,6 +611,14 @@ def cast_model_to_fp16(
if level == 'O2':
amp_lists.black_list = amp_lists.black_list - black_list

if level == 'OD':
if amp_lists is not None:
dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype)

amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list

global_block = program.global_block()
keep_fp32_ops = set()
keep_fp16_ops = set()
