[Typing][C-18,C-68,C-77] Add type annotations for distributed/communication/group.py and incubate/nn/functional/{block_multihead_attention,fused_transformer}.py #67677

Conversation

Your PR was submitted successfully. Thank you for your contribution to the open-source project!
```diff
@@ -154,7 +167,7 @@ def is_initialized():
     return _GroupManager.global_group_id in _GroupManager.group_map_by_id


-def destroy_process_group(group=None):
+def destroy_process_group(group=None) -> None:
```
```suggestion
def destroy_process_group(group: Group | None = None) -> None:
```
```diff
         if self.rank < 0:
             return False
         if self.nranks < 2:
             return False
         return True

-    def get_group_rank(self, rank):
+    def get_group_rank(self, rank) -> int:
```
```suggestion
    def get_group_rank(self, rank: int) -> int | Literal[-1]:
```
```diff
@@ -196,7 +209,7 @@ def destroy_process_group(group=None):
     del _GroupManager.group_map_by_id[group.id]


-def get_group(id=0):
+def get_group(id=0) -> Group:
```
```suggestion
def get_group(id: int = 0) -> Group:
```
```diff
@@ -255,7 +268,7 @@ def _sync_comm_stream(tensor, ring_id=0):
     )


-def wait(tensor, group=None, use_calc_stream=True):
+def wait(tensor, group=None, use_calc_stream=True) -> None:
```
```suggestion
def wait(tensor: Tensor, group: Group | None = None, use_calc_stream: bool = True) -> None:
```
```diff
@@ -291,7 +304,7 @@ def wait(tensor, group=None, use_calc_stream=True):
     _sync_comm_stream(tensor, ring_id)


-def barrier(group=None):
+def barrier(group=None) -> None:
```
```suggestion
def barrier(group: Group | None = None) -> None:
```
```python
_activation_function = Literal["relu", "gelu"]
_mode = Literal["upscale_in_train", "downscale_in_infer"]
_norm_type = Literal["layernorm"]
```
Use CamelCase here: rename `_mode` to `_Mode`.
```python
    ln2_bias: Tensor | None = None,
    dropout1_rate: float = 0.5,
    dropout2_rate: float = 0.5,
    activation: _activation_function = "relu",
```
Why is this restricted to `"relu"` and `"gelu"` only?
Sorry, that was too hasty on my part. I saw that `fused_linear_activation` has this restriction and assumed the same limit applied everywhere.
```python
    ffn2_biases: Sequence[Tensor],
    pre_layer_norm: bool = True,
    epsilon: float = 1e-05,
    residual_alpha: int = 1.0,
```
```suggestion
    residual_alpha: float = 1.0,
```
```python
    mode: _mode = 'upscale_in_train',
    trans_qkvw: bool = True,
    ring_id: int = -1,
    norm_type: _norm_type = "layernorm",
```
Why is only `"layernorm"` allowed?
```python
    use_neox_rotary_style: bool = False,
    gqa_group_size: int = -1,
    name: str | None = None,
) -> Tensor | tuple(Tensor, Tensor):
```
You need to use `overload` here: when `cache_kvs` is `None` the output is `Tensor`, otherwise it is `tuple[Tensor, Sequence[Tensor]]`. Related: #65008
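For illustration, a minimal sketch of that overload pattern, with the signature cut down to just `x` and `cache_kvs` (the real `fused_multi_transformer` takes many more parameters, so this is an assumption-laden simplification, not the actual API):

```python
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, overload

if TYPE_CHECKING:
    from paddle import Tensor


# When cache_kvs is None the call returns a plain Tensor; when caches
# are passed in, it returns the output together with the updated caches.
@overload
def fused_multi_transformer(
    x: Tensor,
    cache_kvs: None = ...,
) -> Tensor: ...
@overload
def fused_multi_transformer(
    x: Tensor,
    cache_kvs: Sequence[Tensor] = ...,
) -> tuple[Tensor, Sequence[Tensor]]: ...


def fused_multi_transformer(x, cache_kvs=None):
    # Real implementation omitted; the overloads above only describe
    # the two static return types seen by the type checker.
    ...
```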
```python
if TYPE_CHECKING:
    from paddle import Tensor

_quant_round_type = Literal[0, 1]
```
```suggestion
_QuantRoundType: TypeAlias = Literal[0, 1]
```
```diff
@@ -12,6 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal, Sequence, overload
```
Import `Sequence` from `collections.abc`.
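For reference, the resulting import block would look roughly like this sketch (`typing.Sequence` has been deprecated in favor of `collections.abc.Sequence` since Python 3.9, PEP 585):

```python
from __future__ import annotations

from collections.abc import Sequence  # not typing.Sequence (deprecated)
from typing import TYPE_CHECKING, Literal, overload

if TYPE_CHECKING:
    from paddle import Tensor
```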
```python
    ffn1_biases: Sequence[Tensor],
    ffn2_weights: Sequence[Tensor],
    ffn2_biases: Sequence[Tensor],
    pre_layer_norm: bool = True,
```
```suggestion
    pre_layer_norm: bool = ...,
```
Note how default values are written inside an `overload`.
```python
    use_neox_rotary_style: bool = False,
    gqa_group_size: int = -1,
    name: str | None = None,
) -> tuple(Tensor, Sequence[Tensor]): ...
```
```suggestion
) -> tuple[Tensor, Sequence[Tensor]]: ...
```
```python
    name: str | None = None,
) -> tuple(Tensor, Sequence[Tensor]): ...


def fused_multi_transformer(
```
Keep the original function unchanged and don't touch its default values ~ writing default values in an `overload` is meaningless, since they cannot be retrieved at runtime ~
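A quick illustration of why overload defaults are invisible at runtime (a standalone sketch with a made-up `scale` function, not Paddle code):

```python
import inspect
from typing import overload


@overload
def scale(x: int, factor: int = ...) -> int: ...
@overload
def scale(x: float, factor: float = ...) -> float: ...


def scale(x, factor=2):
    return x * factor


# Only the implementation's default survives at runtime; the overload
# signatures are discarded by @overload.
print(inspect.signature(scale))  # (x, factor=2)
```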
Import `TypeAlias` from `typing_extensions`.
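A short sketch of that alias style; `typing_extensions` backports `TypeAlias` to Python versions before 3.10:

```python
from typing import Literal

from typing_extensions import TypeAlias

# CamelCase private alias, per the naming comment above.
_QuantRoundType: TypeAlias = Literal[0, 1]
```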
LGTM ~
PR Category

User Experience

PR Types

Improvements

Description

- `ProcessGroup` was found here; I'm not sure whether it is the right one.
- In `fused_multi_transformer`, I could not find the allowed values for `norm_type`, so for now only `layernorm` is permitted.
- The allowed values for `compute_dtype` are also unknown.