[AMP] support OD level for static #53768

Merged: 8 commits, May 16, 2023

Conversation

AnnaTrainingG
Contributor

PR types

Others

PR changes

APIs

Description

Pcard-70458

paddle-bot bot commented May 12, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

paddle-bot bot commented May 12, 2023

❌ The PR is not created using PR's template. You can refer to this Demo.
Please use PR's template, it helps save our maintainers' time so that more developers get helped.

AnnaTrainingG changed the title from "[AMP] support OD level" to "[AMP] support OD level in static" on May 12, 2023
AnnaTrainingG changed the title from "[AMP] support OD level in static" to "[AMP] support OD level for static" on May 12, 2023
amp_lists = AutoMixedPrecisionLists(dtype=amp_dtype)

amp_lists.white_list = {"conv2d", "matmul_v2"}
amp_lists.black_list = amp_lists.all_list - amp_lists.white_list
Contributor

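Lines 880-881 and 887-888 can be merged: at the OD level, just use the default lists.

The warning on line 884 needs one more condition: it should only be emitted when the user has actually set amp_lists.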
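
For illustration, the suggestion in this comment could look roughly like the following. This is a minimal sketch using plain Python sets; `resolve_od_lists` and its parameters are hypothetical names chosen for the example, not Paddle APIs, and the op names are only placeholders.

```python
import warnings


def resolve_od_lists(default_white, all_ops, user_lists=None):
    """Hypothetical helper: choose white/black lists for the OD level.

    At OD, the default white list is used as-is and every other op
    goes to the black list.
    """
    if user_lists is not None:
        # Warn only when the caller actually supplied custom lists.
        warnings.warn(
            "custom amp_lists are ignored at the OD level; "
            "the default lists are used instead"
        )
    white = set(default_white)
    black = set(all_ops) - white
    return white, black


white, black = resolve_od_lists(
    {"conv2d", "matmul_v2"},
    {"conv2d", "matmul_v2", "relu", "softmax"},
)
```

With no `user_lists` argument the call stays silent, which matches the condition requested for the warning.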

Contributor Author

Fixed.

@@ -116,13 +116,13 @@ def _get_sys_unsupported_list(dtype):
     }
     sys_unsupported_list -= supported_fp16_list

-    return device, sys_unsupported_list
+    return device, sys_unsupported_list, sys_unsupported_list
Contributor

Why does this return sys_unsupported_list twice?

Contributor Author

Fixed.

@zhangting2020
Contributor

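Please add a unit test to test/amp/test_amp_api.py. It currently contains dynamic-graph tests only, so a static-graph one can be added. For the static graph, see test/amp/test_amp_promote.py, which uses collect_operator_stats to gather op information.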

@@ -101,7 +101,7 @@ def _get_sys_unsupported_list(dtype):
         device = 'NPU'
     else:
         device = 'GPU'
-    _, _, sys_unsupported_list = core.op_supported_infos(device, var_type)
+    _, _, sys_supported_list = core.op_supported_infos(device, var_type)
Contributor

As far as I can see, op_supported_infos returns all_ops, supported_ops, and unsupported_ops, in that order.
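
To make the point about the return order concrete, here is a small sketch. `fake_op_supported_infos` is a hypothetical stand-in written for this example, not the real `core.op_supported_infos`; it only mimics the (all_ops, supported_ops, unsupported_ops) order described in this comment, and the op names are placeholders.

```python
def fake_op_supported_infos(device, dtype):
    """Hypothetical stand-in mimicking the return order from the review."""
    all_ops = {"conv2d", "matmul_v2", "lookup_table"}
    supported_ops = {"conv2d", "matmul_v2"}
    unsupported_ops = all_ops - supported_ops
    return all_ops, supported_ops, unsupported_ops


# The third element is the unsupported set, so the variable that
# receives it should be named accordingly.
_, _, sys_unsupported_list = fake_op_supported_infos("GPU", "float16")
```

Under that ordering, binding the third element to a name such as sys_supported_list would mislabel the unsupported ops.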

Contributor Author

Fixed.

dtype = get_low_precision_dtypestr(dest_type)
amp_lists = AutoMixedPrecisionLists(dtype)

amp_lists.white_list = {"conv2d", "matmul_v2"}
Contributor

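This should really use the framework's default white list. The white list in the amp_list obtained on line 617 is already the default one, so only the black list needs to change.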

level,
)

def test_static_amp_o1(self):
Contributor

o1->OD

use_promote=True,
expected_op_calls=expected_fp16_calls,
)
paddle.disable_static()
Contributor

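One possible problem with this unit test: the program produced by build_conv_model probably yields the same result as line 87 under O1 as well, so the difference may not show up. You could switch the network to the dynamic-graph model written at line 25, which ensures that O1 and OD actually differ.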
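
The concern can be illustrated with a toy model of the two levels. This is a simplified sketch under stated assumptions, not Paddle's actual AMP pass: `low_precision_ops` is a hypothetical function, ops are treated as a linear chain, and O1 is assumed to let an op that is on neither list follow the precision of the previous op, while OD keeps everything outside the white list in float32.

```python
def low_precision_ops(ops, white, black, level):
    """Toy model: return the ops that would run in low precision."""
    result = []
    prev_low = False
    for op in ops:
        if op in white:
            low = True
        elif level == "OD" or op in black:
            # OD: anything outside the white list stays in float32.
            low = False
        else:
            # O1 (assumed): an unlisted op follows its input precision.
            low = prev_low
        if low:
            result.append(op)
        prev_low = low
    return result


white, black = {"conv2d"}, {"softmax"}

# A conv-only network gives identical results at both levels.
conv_only_o1 = low_precision_ops(["conv2d"], white, black, "O1")
conv_only_od = low_precision_ops(["conv2d"], white, black, "OD")

# Adding an unlisted op makes the two levels diverge.
mixed_o1 = low_precision_ops(["conv2d", "relu"], white, black, "O1")
mixed_od = low_precision_ops(["conv2d", "relu"], white, black, "OD")
```

In this toy setting the conv-only network cannot tell the levels apart, whereas the network with an extra unlisted op can, which is the reason for suggesting a different test model.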

Contributor

zhangting2020 left a comment

The remaining issues will be fixed in the next PR.

@AnnaTrainingG
Contributor Author

ok

AnnaTrainingG merged commit c2c3bd4 into PaddlePaddle:develop on May 16, 2023
2 participants