-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AMP] support OD level for static #53768
Changes from 6 commits
2b78304
88c07d4
dac6817
b651a61
c2f4f51
ea92b93
ea53699
f29804c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -116,13 +116,13 @@ def _get_sys_unsupported_list(dtype): | |
} | ||
sys_unsupported_list -= supported_fp16_list | ||
|
||
return device, sys_unsupported_list | ||
return device, sys_unsupported_list, sys_unsupported_list | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里为啥返回了2个sys_unsupported_list There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已修改 |
||
|
||
|
||
def _get_unsupported_list(dtype): | ||
# The set of ops that don't support fp16 calculation | ||
_, _sys_unsupported_list = _get_sys_unsupported_list(dtype) | ||
return _sys_unsupported_list | ||
_, _sys_unsupported_list, _sys_all_list = _get_sys_unsupported_list(dtype) | ||
return _sys_unsupported_list, _sys_all_list | ||
|
||
|
||
# The three sets listed below are changed dynamiclly. They don't contain all | ||
|
@@ -201,7 +201,9 @@ def __init__( | |
self.white_list = copy.copy(_get_white_list(self.amp_dtype)) | ||
self.black_list = copy.copy(_get_black_list()) | ||
self.gray_list = copy.copy(gray_list) | ||
self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype)) | ||
unsupported_list, sys_all_list = _get_unsupported_list(self.amp_dtype) | ||
self.unsupported_list = copy.copy(unsupported_list) | ||
self.all_list = copy.copy(sys_all_list) | ||
self.black_varnames = copy.copy(custom_black_varnames) | ||
self._update_list() | ||
|
||
|
@@ -233,7 +235,9 @@ def _update_list(self): | |
self.gray_list.remove(op_name) | ||
self.black_list.add(op_name) | ||
self.unsupported_list.add(op_name) | ||
device, sys_unsupported_list = _get_sys_unsupported_list(self.amp_dtype) | ||
device, sys_unsupported_list, _ = _get_sys_unsupported_list( | ||
self.amp_dtype | ||
) | ||
actual_unsupported_list = [] | ||
for op_name in sys_unsupported_list: | ||
if op_name in self.white_list: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -426,7 +426,7 @@ def set_var_dst_dtype( | |
|
||
def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level): | ||
keep_fp32_var_names = set() | ||
if level == "O1": | ||
if level == "O1" or level == "OD": | ||
return keep_fp32_var_names | ||
all_parameters = [] | ||
for block in program.blocks: | ||
|
@@ -611,6 +611,14 @@ def cast_model_to_fp16( | |
if level == 'O2': | ||
amp_lists.black_list = amp_lists.black_list - black_list | ||
|
||
if level == 'OD': | ||
if amp_lists is not None: | ||
dtype = get_low_precision_dtypestr(dest_type) | ||
amp_lists = AutoMixedPrecisionLists(dtype) | ||
|
||
amp_lists.white_list = {"conv2d", "matmul_v2"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里其实应该是用框架默认白名单,也就是617行得到的amp_list中的白名单实际就是默认名单。需要修改的就是黑名单。 |
||
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list | ||
|
||
global_block = program.global_block() | ||
keep_fp32_ops = set() | ||
keep_fp16_ops = set() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
880~881和887~888可以合并下,OD level下,直接用默认名单。
884行的warning还需要一个条件,当用户设置了amp_lists,才需要给warning提示。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改