Skip to content

Commit

Permalink
[Typing][B-33,B-37] Add type annotations for `python/paddle/amp/{amp_…
Browse files Browse the repository at this point in the history
…lists,__init__}.py` (PaddlePaddle#65633)
  • Loading branch information
enkilee authored and lixcli committed Jul 22, 2024
1 parent 81bc005 commit 1fbc397
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
5 changes: 3 additions & 2 deletions python/paddle/amp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from paddle.base import core
from paddle.base.framework import (
Expand Down Expand Up @@ -48,7 +49,7 @@
]


def is_float16_supported(device=None):
def is_float16_supported(device: str | None = None) -> bool:
"""
Determine whether the place supports float16 in the auto-mixed-precision training.
Expand All @@ -75,7 +76,7 @@ def is_float16_supported(device=None):
return core.is_float16_supported(device)


def is_bfloat16_supported(device=None):
def is_bfloat16_supported(device: str | None = None) -> bool:
"""
Determine whether the place supports bfloat16 in the auto-mixed-precision training.
Expand Down
7 changes: 5 additions & 2 deletions python/paddle/amp/amp_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

# The set of ops that support fp16 and bf16 calculation and are considered numerically-
# safe and performance-critical. These ops are always converted to fp16 or bf16.

from __future__ import annotations

WHITE_LIST = {
'conv2d',
'einsum',
Expand Down Expand Up @@ -102,7 +105,7 @@


# At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
def white_list():
def white_list() -> dict[str, dict[str, set[str]]]:
white_list = {
"float16": {
"OD": FP16_WHITE_LIST,
Expand All @@ -118,7 +121,7 @@ def white_list():
return white_list


def black_list():
def black_list() -> dict[str, dict[str, set[str]]]:
black_list = {
"float16": {
"OD": set(),
Expand Down

0 comments on commit 1fbc397

Please sign in to comment.