
[Typing][B-33,B-37] Add type annotations for python/paddle/amp/{amp_lists,__init__}.py #65633

Merged · 4 commits · Jul 14, 2024
Changes from 2 commits
5 changes: 3 additions & 2 deletions python/paddle/amp/__init__.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import annotations

from paddle.base import core
from paddle.base.framework import (
@@ -48,7 +49,7 @@
]


-def is_float16_supported(device=None):
+def is_float16_supported(device: str | None = None) -> bool:
"""
Determine whether the place supports float16 in the auto-mixed-precision training.

@@ -75,7 +76,7 @@ def is_float16_supported(device=None):
return core.is_float16_supported(device)


-def is_bfloat16_supported(device=None):
+def is_bfloat16_supported(device: str | None = None) -> bool:
"""
Determine whether the place supports bfloat16 in the auto-mixed-precision training.

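For reference, a minimal usage sketch of the two public APIs as annotated above. The calls and the `str | None` device parameter come straight from the diff; the "cpu" device string is an assumption for illustration, since the accepted device strings follow the (elided) docstrings.

    import paddle

    # Both predicates now advertise `device: str | None = None` and return `bool`.
    if paddle.amp.is_float16_supported():        # None checks the current place
        print("float16 AMP is supported here")
    if paddle.amp.is_bfloat16_supported("cpu"):  # "cpu" is an assumed device string
        print("bfloat16 AMP is supported on CPU")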
7 changes: 5 additions & 2 deletions python/paddle/amp/amp_lists.py
Member:

Neither of these is a public API. Was the tally for this file counted wrong, or did it miss marking some public APIs?

Contributor:

These two were added manually on my side ~

My tally basically follows the scope of __all__; a few functions aren't listed there but are imported in __init__.py, and since they looked simple enough I included them as well ~

@@ -14,6 +14,9 @@

# The set of ops that support fp16 and bf16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16 or bf16.

Member:

PEP 563
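
A brief note on what this refers to: PEP 563 postpones annotation evaluation, so with `from __future__ import annotations` (as added to `__init__.py` above) annotations are stored as strings and never evaluated at runtime. Builtin generics such as `dict[str, set[str]]` then stay valid even on Python versions predating PEP 585. A minimal sketch, assuming that is the reviewer's intent here:

    from __future__ import annotations  # PEP 563: annotations become lazily evaluated strings

    # Safe even on Python 3.8, where bare `dict[...]` subscripting
    # in an evaluated annotation would raise a TypeError.
    def white_list() -> dict[str, dict[str, set[str]]]:
        ...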

+from typing import Any

WHITE_LIST = {
'conv2d',
'einsum',
@@ -102,7 +105,7 @@


# At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
-def white_list():
+def white_list() -> dict[str, dict[str, set[str]]]:
white_list = {
"float16": {
"OD": FP16_WHITE_LIST,
@@ -118,7 +121,7 @@ def white_list():
return white_list


-def black_list():
+def black_list() -> dict[str, dict[str, Any]]:
Contributor:

Suggested change:
-def black_list() -> dict[str, dict[str, Any]]:
+def black_list() -> dict[str, dict[str, set[str]]]:

This one should still be set[str] as well, it seems ~

Also, @SigureMo, is it necessary to use a TypedDict here?

Member:

No need, I think. This isn't a public API anyway, so the input/output "contract" here can be left to the developers working on it; switching to a TypedDict would actually make it harder for them to maintain.
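
For context, a hypothetical sketch of the TypedDict alternative being discussed, which the PR does not adopt. The class name is invented, and the level keys other than "OD" are assumptions not shown in this diff:

    from __future__ import annotations

    from typing import TypedDict

    class _DtypeLevelLists(TypedDict):  # hypothetical name, never added to the PR
        OD: set[str]
        O1: set[str]  # assumed additional AMP levels
        O2: set[str]

    def black_list() -> dict[str, _DtypeLevelLists]:
        ...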

black_list = {
"float16": {
"OD": set(),
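
Finally, a minimal sketch of how these module-internal helpers would be consumed, assuming the structure visible in the diff (dtype, then AMP level, then a set of op names); the module path follows the file path python/paddle/amp/amp_lists.py:

    from paddle.amp.amp_lists import black_list, white_list

    # Indexing by dtype and level yields a set[str] of op names,
    # which is what the annotations above encode.
    fp16_od_ops = white_list()["float16"]["OD"]
    fp16_od_blocked = black_list()["float16"]["OD"]  # empty set per the diff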