[AMP] support OD level for static #53768

Merged (8 commits) on May 16, 2023
Changes from 7 commits
44 changes: 25 additions & 19 deletions python/paddle/static/amp/decorator.py
@@ -61,12 +61,11 @@ class OptimizerWithMixedPrecision:
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
level(str): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represent mixed precision, the input data type
of each operator will be casted by white_list and black_list;
O2 represent Pure fp16 or bf16, all operators parameters and input
data will be casted to fp16 or bf16, except operators in black_list,
don't support fp16 or bf16 kernel and batch_norm.
level(str): Auto mixed precision level. Accepted values are "O1", "O2" and "OD". At the O1 level, operators in the white list
use float16/bfloat16 inputs for their calculations, while operators in the black list use float32 inputs. At the O2 level,
the model's parameters are cast to float16/bfloat16 by the decorator; operators whose inputs are all float16/bfloat16 run in
float16/bfloat16, and operators with any float32 input run in float32. At the OD level, only operators in the
default white list compute in float16/bfloat16, and all other operators compute in float32.
dtype(str): Whether to use 'float16' or 'bfloat16'.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
@@ -123,6 +122,7 @@ def __init__(
self._learning_rate = optimizer._learning_rate
self._learning_rate_map = optimizer._learning_rate_map
self._use_pure_fp16 = level == "O2"
self._amp_level = level
self._use_fp16_guard = use_amp_guard
self._to_fp16_var_names = None
if self._use_dynamic_loss_scaling:
@@ -258,7 +258,7 @@ def backward(
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
level=self._amp_level,
use_promote=self.use_promote,
)

@@ -397,7 +397,7 @@ def run_example_code():
self._amp_lists,
use_fp16_guard=False,
dest_type=self._amp_vartype,
level='O1',
level=self._amp_level,
use_promote=self.use_promote,
)

@@ -800,12 +800,11 @@ def decorate(
amp_lists(CustomOpLists, optional): An CustomOpLists object. The default
white_list and black_list will be used for AMP training when it is
not set. Default is None.
level(str, optional): Auto mixed precision level. Accepted values are
"O1" and "O2": O1 represent mixed precision, the input data type of
each operator will be casted by white_list and black_list;
O2 represent pure FP16 / BF16 training, all operators parameters
and input data will be casted to FP16 / BF16, except operators in
black_list, don't support FP16 / BF16 kernel and batch_norm. Default is O1.
level(str, optional): Auto mixed precision level. Accepted values are "O1", "O2" and "OD". At the O1 level, operators in the
white list use float16/bfloat16 inputs for their calculations, while operators in the black list use float32 inputs. At the O2
level, the model's parameters are cast to float16/bfloat16 by the decorator; operators whose inputs are all float16/bfloat16 run
in float16/bfloat16, and operators with any float32 input run in float32. At the OD level, only operators in the
default white list compute in float16/bfloat16, and all other operators compute in float32. Default is O1.
dtype(str, optional): Whether to use 'float16' or 'bfloat16'. Default is 'float16'.
master_weight(bool, optional): For level='O2', whether to use multi-precision
during weight updating. If master_weight is None, in O2 level optimizer
@@ -874,15 +873,22 @@ def forward(self, x):
"""
# check amp_level: O0-O2
level = level.upper()
if not (level in ['O0', 'O1', 'O2']):
raise ValueError(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
if not (level in ['O0', 'OD', 'O1', 'O2']):
raise ValueError("level should be O0, OD, O1 or O2.")

amp_dtype = check_amp_dtype(dtype)
if amp_lists is None:
if amp_lists is None or level == 'OD':
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

if level == 'OD':
if amp_lists is not None:
warnings.warn(
"If the Amp level is set to OD, the amp list will not be used."
)

amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
Contributor

Lines 880-881 and 887-888 can be merged: under the OD level, just use the default lists directly.

The warning at line 884 also needs a condition: it should only be issued when the user has explicitly set amp_lists.

Contributor Author

Fixed.


if use_dynamic_loss_scaling is None:
use_dynamic_loss_scaling = dtype == "float16"

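For reference, a minimal sketch of how the new "OD" level is meant to be invoked through `paddle.static.amp.decorate` in static-graph mode, based on the docstring above. The toy network, optimizer and scaling settings are illustrative assumptions, not code from this PR, and a GPU build of Paddle is assumed.

```python
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Hypothetical toy network: one conv2d followed by a mean loss.
    x = paddle.static.data(name="x", shape=[None, 1, 28, 28], dtype="float32")
    conv = paddle.static.nn.conv2d(input=x, num_filters=6, filter_size=3)
    loss = paddle.mean(conv)

    optimizer = paddle.optimizer.SGD(learning_rate=0.001)
    # level="OD": only operators in the default white list (e.g. conv2d,
    # matmul_v2) run in float16; every other operator stays in float32.
    optimizer = paddle.static.amp.decorate(
        optimizer,
        level="OD",
        dtype="float16",
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True,
    )
    optimizer.minimize(loss)
```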
18 changes: 11 additions & 7 deletions python/paddle/static/amp/fp16_lists.py
@@ -101,7 +101,7 @@ def _get_sys_unsupported_list(dtype):
device = 'NPU'
else:
device = 'GPU'
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
_, _, sys_supported_list = core.op_supported_infos(device, var_type)
Contributor

As I read it, the return values of op_supported_infos are all_ops, supported_ops and unsupported_ops.

Contributor Author

Fixed.


# sys_unsupported_list will include the following ops.
supported_fp16_list = {
@@ -114,15 +114,15 @@ def _get_sys_unsupported_list(dtype):
"lod_array_length",
"write_to_array",
}
sys_unsupported_list -= supported_fp16_list
sys_unsupported_list = sys_supported_list - supported_fp16_list

return device, sys_unsupported_list
return device, sys_unsupported_list, sys_supported_list


def _get_unsupported_list(dtype):
# The set of ops that don't support fp16 calculation
_, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
return _sys_unsupported_list
_, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype)
return _sys_unsupported_list, _sys_all_list


# The three sets listed below are changed dynamically. They don't contain all
@@ -201,7 +201,9 @@ def __init__(
self.white_list = copy.copy(_get_white_list(self.amp_dtype))
self.black_list = copy.copy(_get_black_list())
self.gray_list = copy.copy(gray_list)
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
unsupported_list, sys_all_list = _get_unsupported_list(self.amp_dtype)
self.unsupported_list = copy.copy(unsupported_list)
self.all_list = copy.copy(sys_all_list)
self.black_varnames = copy.copy(custom_black_varnames)
self._update_list()

@@ -233,7 +235,9 @@ def _update_list(self):
self.gray_list.remove(op_name)
self.black_list.add(op_name)
self.unsupported_list.add(op_name)
device, sys_unsupported_list = _get_sys_unsupported_list(self.amp_dtype)
device, sys_unsupported_list, _ = _get_sys_unsupported_list(
self.amp_dtype
)
actual_unsupported_list = []
for op_name in sys_unsupported_list:
if op_name in self.white_list:
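The point of exposing the full op list here is to let the OD level derive its black list by set subtraction. A hypothetical, self-contained illustration of that arithmetic follows; the op names are stand-ins, not the framework's real inventory.

```python
# Stand-in sets; the real white_list and all_list come from
# AutoMixedPrecisionLists after this change.
all_list = {"conv2d", "matmul_v2", "relu", "softmax", "reduce_mean"}
white_list = {"conv2d", "matmul_v2"}  # default low-precision ops

# Under the OD level, everything outside the white list is blacklisted,
# i.e. forced to compute in float32.
black_list = all_list - white_list
assert black_list == {"relu", "softmax", "reduce_mean"}
```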
10 changes: 9 additions & 1 deletion python/paddle/static/amp/fp16_utils.py
@@ -426,7 +426,7 @@ def set_var_dst_dtype(

def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
keep_fp32_var_names = set()
if level == "O1":
if level == "O1" or level == "OD":
return keep_fp32_var_names
all_parameters = []
for block in program.blocks:
@@ -611,6 +611,14 @@ def cast_model_to_fp16(
if level == 'O2':
amp_lists.black_list = amp_lists.black_list - black_list

if level == 'OD':
if amp_lists is not None:
dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype)

amp_lists.white_list = {"conv2d", "matmul_v2"}
Contributor

This should actually use the framework's default white list; the white list in the amp_lists obtained at line 617 is already the default one. Only the black list needs to be modified.

amp_lists.black_list = amp_lists.all_list - amp_lists.white_list

global_block = program.global_block()
keep_fp32_ops = set()
keep_fp16_ops = set()
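Following the review comment above, a rough sketch of what the OD branch could look like once it reuses the framework's default white list and only rewrites the black list. This is an assumption about the follow-up change, not the merged code; the import paths and the `all_list` attribute are inferred from this diff.

```python
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists
from paddle.static.amp.fp16_utils import get_low_precision_dtypestr


def build_od_lists(dest_type):
    """Build AMP lists for the OD level (sketch only)."""
    dtype = get_low_precision_dtypestr(dest_type)
    # Default white/black/gray lists for the destination dtype.
    amp_lists = AutoMixedPrecisionLists(dtype)
    # Keep the default white list; every other op falls back to float32.
    amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
    return amp_lists
```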
70 changes: 69 additions & 1 deletion test/amp/test_amp_api.py
@@ -14,9 +14,11 @@

import unittest

from amp_base_models import AmpTestBase
import numpy as np
from amp_base_models import AmpTestBase, build_conv_model

import paddle
from paddle.static import amp


class TestAutoCast(AmpTestBase):
@@ -35,6 +37,72 @@ def test_amp_OD_level(self):
self.assertEqual(out3.dtype, paddle.float32)


class TestStaticDecorate(AmpTestBase):
def check_results(
self, use_amp, dtype, level, use_promote, expected_op_calls
):
(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
) = build_conv_model(use_amp, dtype, level, use_promote)
self.assertEqual(main_program.num_blocks, 1)
optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
level=level,
)

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)

self._check_op_calls(
op_stats_list[0], expected_fp16_calls=expected_op_calls
)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)

max_iters = 2
x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
losses_o1 = self.run_program(
main_program,
startup_program,
optimizer,
feed_vars,
fetch_vars,
place,
exe,
x_fp32,
max_iters,
level,
)

def test_static_amp_o1(self):
Contributor

o1->OD

paddle.enable_static()
expected_fp16_calls = {
"conv2d": 1,
"elementwise_add": 0,
"relu": 0,
"matmul_v2": 1,
"softmax": 0,
"reduce_mean": 0,
"adamw": 0,
}
self.check_results(
True,
'float16',
'OD',
use_promote=True,
expected_op_calls=expected_fp16_calls,
)
paddle.disable_static()
Contributor

One possible problem with this unit test is that the program built by build_conv_model would probably give roughly the same result as line 87 under O1 as well, so the difference may not show up. You could switch the network to the dynamic-graph model written at line 25, which would make sure O1 and OD actually differ.



class TestGradScaler(AmpTestBase):
def test_amp_grad_scaler(self):
model = paddle.nn.Conv2D(3, 2, 3)
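Mirroring the new test above, a small self-contained sketch of how the operator-stats helpers can be used to check which ops actually run in float16 under the OD level. The toy program is an assumption, the debugging helpers are used exactly as in the test (they may be internal/unstable APIs), and a GPU build of Paddle is assumed.

```python
import paddle
from paddle.static import amp

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 1, 6, 6], dtype="float32")
    out = paddle.static.nn.conv2d(input=x, num_filters=2, filter_size=3)
    loss = paddle.mean(out)
    opt = paddle.optimizer.SGD(learning_rate=0.001)
    opt = paddle.static.amp.decorate(opt, level="OD", dtype="float16")
    opt.minimize(loss)

# Count per-op low-precision calls recorded in the program, as the new test
# does; under OD only white-list ops such as conv2d and matmul_v2 should
# report float16 calls, while black-listed ops stay in float32.
amp.debugging.collect_operator_stats(main_prog)
op_stats = amp.debugging._get_op_stats_list(main_prog)[0]
print(op_stats)
```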