[AMP] support OD level for static #53768
Changes from 7 commits
First changed file:
@@ -101,7 +101,7 @@ def _get_sys_unsupported_list(dtype):
        device = 'NPU'
    else:
        device = 'GPU'
-    _, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
+    _, _, sys_supported_list = core.op_supported_infos(device, var_type)
Review comment: As far as I can see, op_supported_infos returns all_ops, supported_ops and unsupported_ops.
Reply: Fixed.

    # sys_unsupported_list will include the following ops.
    supported_fp16_list = {
@@ -114,15 +114,15 @@ def _get_sys_unsupported_list(dtype):
        "lod_array_length",
        "write_to_array",
    }
-    sys_unsupported_list -= supported_fp16_list
+    sys_unsupported_list = sys_supported_list - supported_fp16_list

-    return device, sys_unsupported_list
+    return device, sys_unsupported_list, sys_supported_list


def _get_unsupported_list(dtype):
    # The set of ops that don't support fp16 calculation
-    _, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
-    return _sys_unsupported_list
+    _, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype)
+    return _sys_unsupported_list, _sys_all_list


# The three sets listed below are changed dynamiclly. They don't contain all
@@ -201,7 +201,9 @@ def __init__(
        self.white_list = copy.copy(_get_white_list(self.amp_dtype))
        self.black_list = copy.copy(_get_black_list())
        self.gray_list = copy.copy(gray_list)
-        self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
+        unsupported_list, sys_all_list = _get_unsupported_list(self.amp_dtype)
+        self.unsupported_list = copy.copy(unsupported_list)
+        self.all_list = copy.copy(sys_all_list)
        self.black_varnames = copy.copy(custom_black_varnames)
        self._update_list()
@@ -233,7 +235,9 @@ def _update_list(self):
                self.gray_list.remove(op_name)
                self.black_list.add(op_name)
                self.unsupported_list.add(op_name)
-        device, sys_unsupported_list = _get_sys_unsupported_list(self.amp_dtype)
+        device, sys_unsupported_list, _ = _get_sys_unsupported_list(
+            self.amp_dtype
+        )
        actual_unsupported_list = []
        for op_name in sys_unsupported_list:
            if op_name in self.white_list:
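For orientation, a hedged sketch of what the new plumbing exposes once these commits are applied. The all_list attribute is introduced by this PR, so the snippet is an illustration against this branch rather than released Paddle:

# Illustration only: assumes the fp16_lists changes from this PR are applied.
from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists

lists = AutoMixedPrecisionLists(dtype="float16")
# unsupported_list: ops with no fp16 kernel registered for the current device.
# all_list (new in this PR): every op registered for the device.
print(len(lists.unsupported_list), len(lists.all_list))
# The OD level later derives its black list from all_list:
od_black_list = lists.all_list - lists.white_list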
Second changed file:
@@ -426,7 +426,7 @@ def set_var_dst_dtype(


def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
    keep_fp32_var_names = set()
-    if level == "O1":
+    if level == "O1" or level == "OD":
        return keep_fp32_var_names
    all_parameters = []
    for block in program.blocks:
@@ -611,6 +611,14 @@ def cast_model_to_fp16(
    if level == 'O2':
        amp_lists.black_list = amp_lists.black_list - black_list

+    if level == 'OD':
+        if amp_lists is not None:
+            dtype = get_low_precision_dtypestr(dest_type)
+            amp_lists = AutoMixedPrecisionLists(dtype)
+
+        amp_lists.white_list = {"conv2d", "matmul_v2"}
Review comment: This should actually use the framework's default white list; the white list in the amp_lists obtained at line 617 is already that default. What needs to change is the black list.
+
+        amp_lists.black_list = amp_lists.all_list - amp_lists.white_list

    global_block = program.global_block()
    keep_fp32_ops = set()
    keep_fp16_ops = set()
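Per the review note, the intent is to keep the framework's default white list and rebuild only the black list. A standalone toy illustration of that set logic, using op names borrowed from the test below rather than the real kernel registry:

# Toy illustration of the OD-level list logic; op names are examples only.
def build_od_lists(all_ops, default_white_list):
    """OD level: only default white-list ops run in low precision."""
    white_list = set(default_white_list)    # framework defaults, left untouched
    black_list = set(all_ops) - white_list  # everything else stays fp32
    return white_list, black_list

white, black = build_od_lists(
    all_ops={"conv2d", "matmul_v2", "relu", "softmax", "elementwise_add"},
    default_white_list={"conv2d", "matmul_v2"},
)
assert black == {"relu", "softmax", "elementwise_add"}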
Third changed file:
@@ -14,9 +14,11 @@

import unittest

-from amp_base_models import AmpTestBase
+import numpy as np
+from amp_base_models import AmpTestBase, build_conv_model

import paddle
from paddle.static import amp


class TestAutoCast(AmpTestBase):
@@ -35,6 +37,72 @@ def test_amp_OD_level(self):
        self.assertEqual(out3.dtype, paddle.float32)


+class TestStaticDecorate(AmpTestBase):
+    def check_results(
+        self, use_amp, dtype, level, use_promote, expected_op_calls
+    ):
+        (
+            main_program,
+            startup_program,
+            optimizer,
+            feed_vars,
+            fetch_vars,
+        ) = build_conv_model(use_amp, dtype, level, use_promote)
+        self.assertEqual(main_program.num_blocks, 1)
+        optimizer = paddle.fluid.optimizer.Adadelta(learning_rate=0.001)
+        optimizer = paddle.static.amp.decorate(
+            optimizer,
+            init_loss_scaling=128.0,
+            use_dynamic_loss_scaling=True,
+            level=level,
+        )
+
+        amp.debugging.collect_operator_stats(main_program)
+        op_stats_list = amp.debugging._get_op_stats_list(main_program)
+
+        self._check_op_calls(
+            op_stats_list[0], expected_fp16_calls=expected_op_calls
+        )
+
+        place = paddle.CUDAPlace(0)
+        exe = paddle.static.Executor(place)
+
+        max_iters = 2
+        x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
+        losses_o1 = self.run_program(
+            main_program,
+            startup_program,
+            optimizer,
+            feed_vars,
+            fetch_vars,
+            place,
+            exe,
+            x_fp32,
+            max_iters,
+            level,
+        )
+
+    def test_static_amp_o1(self):
Review comment: o1 -> OD.
+        paddle.enable_static()
+        expected_fp16_calls = {
+            "conv2d": 1,
+            "elementwise_add": 0,
+            "relu": 0,
+            "matmul_v2": 1,
+            "softmax": 0,
+            "reduce_mean": 0,
+            "adamw": 0,
+        }
+        self.check_results(
+            True,
+            'float16',
+            'OD',
+            use_promote=True,
+            expected_op_calls=expected_fp16_calls,
+        )
+        paddle.disable_static()
Review comment: One possible problem with this unit test is that the program built by build_conv_model probably gives the same result as line 87 under O1 as well, so the difference may not show up. You could switch the network to the dygraph-style model written at line 25; that would make sure O1 and OD actually differ.
+
+
class TestGradScaler(AmpTestBase):
    def test_amp_grad_scaler(self):
        model = paddle.nn.Conv2D(3, 2, 3)
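For reference, the user-facing flow this test exercises looks roughly like the snippet below. It is a minimal sketch assuming the OD level added by this PR; the toy network, shapes and hyperparameters are made up for illustration, and a CUDA build is assumed:

import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data('x', shape=[None, 1, 6, 6], dtype='float32')
    conv = paddle.static.nn.conv2d(x, num_filters=4, filter_size=3)
    loss = paddle.mean(conv)
    opt = paddle.optimizer.SGD(learning_rate=0.001)
    # level='OD': only default white-list ops such as conv2d and matmul_v2
    # are cast to fp16; every other op keeps running in fp32.
    opt = paddle.static.amp.decorate(
        opt, level='OD', init_loss_scaling=128.0, use_dynamic_loss_scaling=True
    )
    opt.minimize(loss)

exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
exe.run(
    main_prog,
    feed={'x': np.random.random([2, 1, 6, 6]).astype('float32')},
    fetch_list=[loss],
)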
Review comment: Lines 880-881 and 887-888 can be merged: at the OD level, just use the default lists directly. The warning at line 884 also needs an extra condition; it should only be emitted when the user has explicitly set amp_lists.
Reply: Fixed.
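For readers without access to lines 880-888, the shape being asked for is roughly the following. This is a hedged reconstruction from the comment alone; the helper name, warning text and surrounding wiring are assumptions, not the merged code:

import warnings

from paddle.static.amp.fp16_lists import AutoMixedPrecisionLists


def _apply_od_lists(user_amp_lists, dtype):
    # Hypothetical helper: OD always falls back to the framework default lists.
    if user_amp_lists is not None:
        # Warn only when the caller explicitly passed custom lists.
        warnings.warn("custom amp_lists is ignored when level='OD'")
    amp_lists = AutoMixedPrecisionLists(dtype=dtype)
    # Only the black list is rewritten: everything outside the default
    # white list stays in fp32. (all_list is the attribute added by this PR.)
    amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
    return amp_lists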