[Typing][C-40][BUAA] Add type annotations for 1 file in python/paddle/distributed/fleet/base/distributed_strategy.py
#67405
Your PR has been submitted successfully. Thank you for contributing to this open-source project!
Request a review
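I haven't finished reviewing everything. There are quite a few problems, mostly in how `TypedDict` is used; please look into how this type annotation works ~ Also, even a plain `dict` needs two type parameters, e.g. `dict[x, y]` rather than `dict[x]`.
In addition, the instance attributes of `DistributedStrategy`, such as `strategy`, need to be annotated; see #67448 (review).
Finally, the PR description does not follow the required format. See https://github.com/PaddlePaddle/Paddle/issues/65008 for the requirements.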
@@ -299,7 +307,7 @@ def build_strategy(self):

     @build_strategy.setter
     @is_strict_auto
-    def build_strategy(self, strategy):
+    def build_strategy(self, strategy: str) -> None:
-    def build_strategy(self, strategy: str) -> None:
+    def build_strategy(self, strategy: BuildStrategy) -> None:
@@ -313,7 +321,7 @@ def build_strategy(self, strategy):
         )

     @property
-    def gradient_scale_configs(self):
+    def gradient_scale_configs(self) -> dict[str]:
-    def gradient_scale_configs(self) -> dict[str]:
+    def gradient_scale_configs(self) -> dict[str, Any]:
@@ -332,7 +340,7 @@ def gradient_scale_configs(self):

     @gradient_scale_configs.setter
     @is_strict_auto
-    def gradient_scale_configs(self, config):
+    def gradient_scale_configs(self, config: dict[str]) -> None:
-    def gradient_scale_configs(self, config: dict[str]) -> None:
+    def gradient_scale_configs(self, config: dict[str, Any]) -> None:
@@ -367,7 +375,7 @@ def a_sync(self):

     @a_sync.setter
     @is_strict_auto
-    def a_sync(self, flag):
+    def a_sync(self, flag: bool) -> None:
The getter `def a_sync(self):` has not been annotated.
@@ -377,7 +385,7 @@ def a_sync(self, flag):
         )

     @property
-    def a_sync_configs(self):
+    def a_sync_configs(self) -> dict[int | bool]:
Use `TypedDict` here.
         assert isinstance(flag, bool), "qat should have value of bool type"
         self.strategy.qat = flag

     @property
-    def qat_configs(self):
+    def qat_configs(self) -> dict[bool | int | str | list[str]]:
Use `TypedDict` here as well.
@@ -1085,12 +1103,12 @@ def qat_configs(self):
         return get_msg_dict(self.strategy.qat_configs)

     @qat_configs.setter
-    def qat_configs(self, configs):
+    def qat_configs(self, configs: dict[bool | int | str | list[str]]) -> None:
Same as above.
         if isinstance(flag, bool):
             self.strategy.recompute = flag
         else:
             logger.warning("recompute should have value of bool type")

     @property
-    def sync_nccl_allreduce(self):
+    def sync_nccl_allreduce(self) -> None:
-    def sync_nccl_allreduce(self) -> None:
+    def sync_nccl_allreduce(self) -> bool:
         if isinstance(value, int):
             self.strategy.fuse_grad_size_in_MB = value
         else:
             logger.warning("fuse_grad_size_in_MB should have value of int type")

     @property
-    def last_comm_group_size_MB(self):
+    def last_comm_group_size_MB(self) -> float:
-    def last_comm_group_size_MB(self) -> float:
+    def last_comm_group_size_MB(self) -> int:
@@ -1295,14 +1313,14 @@ def last_comm_group_size_MB(self):

     @last_comm_group_size_MB.setter
     @is_strict_auto
-    def last_comm_group_size_MB(self, value):
+    def last_comm_group_size_MB(self, value: float) -> None:
-    def last_comm_group_size_MB(self, value: float) -> None:
+    def last_comm_group_size_MB(self, value: int) -> None:
Does `TypedDict` have to be defined with `class`?
Related to #65008
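This file involves a lot of `protobuf` and some private members, which makes it rather tricky. Thanks for the hard work ~
P.S. Functions defined inside other functions generally do not need annotations.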
         check_configs_key(
             self.strategy.fs_client_param, configs, "fs_client_param"
         )
         assign_configs_value(self.strategy.fs_client_param, configs)

     @property
-    def sparse_table_configs(self):
+    def sparse_table_configs(self) -> SparseTableConf:
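Let's change the places that use `SparseTableConf` to `dict[str, Any]` ~
First, judging from the logic in the code,
    for table_name in configs:
        ...
`configs` should have the form `dict[str, dict[str, Any]]`, i.e. a nested `dict`. Second, the nested `dict` contains more than the keys in the list, e.g. `if not configs.get("use_cvm", True):` ~ Moreover, this table is very large and would be hard to maintain later, so I suggest using `dict[str, Any]` instead ~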
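A small sketch of why `dict[str, Any]` fits this shape, assuming Python 3.9+. The table names, the values, and the helper `tables_without_cvm` are hypothetical and only mirror the loop quoted above:

```python
from typing import Any

# Hypothetical table names and values; only the nested shape matters here.
configs: dict[str, Any] = {
    "emb": {"use_cvm": False, "shard_num": 37},
    "ctr": {"shard_num": 11},
}


def tables_without_cvm(configs: dict[str, Any]) -> list[str]:
    """Return the names of tables whose config turns `use_cvm` off."""
    names: list[str] = []
    for table_name in configs:
        table = configs[table_name]  # each value is itself a dict
        # A key missing from the table falls back to the default, True.
        if not table.get("use_cvm", True):
            names.append(table_name)
    return names
```

Because the values are typed as `Any`, keys that are not part of any fixed list remain valid, and no large `TypedDict` has to be kept in sync with the code.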
@@ -772,7 +969,9 @@ def sparse_optimizer_config(sgd, strategy, prefix):
                 )
                 sgd.adam.weight_bounds.extend(bounds)

-        def set_sparse_table_config(table_data, config):
+        def set_sparse_table_config(
+            table_data: Table, config: dict[float]
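A few problems:
- `dict` needs two type parameters
- The type here is not `Table`
- Inner functions generally do not need annotations
This touches some `protobuf` members that are internal to the `distributed` module, so let's leave it unannotated ~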
+        def sparse_optimizer_config(
+            sgd: Table, strategy: DistributedStrategy, prefix: str
+        ) -> None:
Leave this unannotated.
@@ -189,38 +368,50 @@ def __init__(self):
         DistributedStrategy supports configurations from BuildStrategy.

         """
-        self.strategy = distributed_strategy_pb2.DistributedStrategy()
+        self.strategy: DistributedStrategy = (
-        self.strategy: DistributedStrategy = (
+        self.strategy: Any = (
The `strategy` here is not the same thing as `DistributedStrategy`: `distributed_strategy_pb2.DistributedStrategy` is generated by `protobuf`, so just use `Any` here ~
from paddle.static import BuildStrategy


class SyncConf(TypedDict):
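For these `TypedDict` definitions:
- Put them inside the `TYPE_CHECKING` block
- Start the names with `_` to mark them as internal, e.g. `_SyncConf`
- Use the `total=False` form throughout, e.g. `class _SyncConf(TypedDict, total=False)`
- Do not add `| None` anywhere, unless the documentation explicitly says `None` is allowed
- Do not add `SparseTableConf`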
LGTM ~
PR Category
User Experience
PR Types
Improvements
Description
Type annotations:
Related links
@SigureMo @megemini