-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AMP] support OD level for static #53768
[AMP] support OD level for static #53768
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
❌ The PR is not created using PR's template. You can refer to this Demo. |
fa29c2f
to
2b78304
Compare
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype) | ||
|
||
amp_lists.white_list = {"conv2d", "matmul_v2"} | ||
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
880~881和887~888可以合并下,OD level下,直接用默认名单。
884行的warning还需要一个条件,当用户设置了amp_lists,才需要给warning提示。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
@@ -116,13 +116,13 @@ def _get_sys_unsupported_list(dtype): | |||
} | |||
sys_unsupported_list -= supported_fp16_list | |||
|
|||
return device, sys_unsupported_list | |||
return device, sys_unsupported_list, sys_unsupported_list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里为啥返回了2个sys_unsupported_list
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
增加一个单测到test/amp/test_amp_api.py中,里面目前有动态图的单测,可以加个静态图的,静态图可以参考test/amp/test_amp_promote.py用collect_operator_stats收集OP信息。 |
@@ -101,7 +101,7 @@ def _get_sys_unsupported_list(dtype): | |||
device = 'NPU' | |||
else: | |||
device = 'GPU' | |||
_, _, sys_unsupported_list = core.op_supported_infos(device, var_type) | |||
_, _, sys_supported_list = core.op_supported_infos(device, var_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我看op_supported_infos这个接口的返回值是all_ops、supported_ops和unsupported_ops
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
dtype = get_low_precision_dtypestr(dest_type) | ||
amp_lists = AutoMixedPrecisionLists(dtype) | ||
|
||
amp_lists.white_list = {"conv2d", "matmul_v2"} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里其实应该是用框架默认白名单,也就是617行得到的amp_list中的白名单实际就是默认名单。需要修改的就是黑名单。
level, | ||
) | ||
|
||
def test_static_amp_o1(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
o1->OD
use_promote=True, | ||
expected_op_calls=expected_fp16_calls, | ||
) | ||
paddle.disable_static() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个单测可能存在一个问题就是build_conv_model得到的program,在O1下大概也是87行的结果,可能会看不出差异。你可以把网络换成25行那个动态图的模型写法,这样能确保O1和OD是有差异的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
剩余问题下一个PR修改
ok |
PR types
Others
PR changes
APIs
Description
Pcard-70458