[Typing][C-18,C-68,C-77] Add type annotations for distributed/communication/group.py and incubate/nn/functional/{block_multihead_attention,fused_transformer}.py #67677

Merged: 5 commits, Aug 26, 2024
Changes from 4 commits
52 changes: 34 additions & 18 deletions python/paddle/distributed/communication/group.py
@@ -12,19 +12,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal

import paddle
import paddle.distributed as dist
from paddle import framework

if TYPE_CHECKING:
from paddle import Tensor
from paddle.base.core import ProcessGroup


class Group:
"""
The abstract representation of group.
"""

def __init__(self, rank_in_group, id, ranks, pg=None, name=None):
def __init__(
self,
rank_in_group: int,
id: int,
ranks: list[int],
pg: ProcessGroup | None = None,
name: str | None = None,
) -> None:
self._rank_in_group = rank_in_group
self._world_size = len(ranks) if rank_in_group >= 0 else -1
self._id = id
@@ -33,51 +47,51 @@ def __init__(self, rank_in_group, id, ranks, pg=None, name=None):
self._name = name

@property
def rank(self):
def rank(self) -> int:
return self._rank_in_group

@property
def ranks(self):
def ranks(self) -> list[int]:
return self._ranks

@property
def nranks(self):
def nranks(self) -> int:
return len(self._ranks)

@property
def name(self):
def name(self) -> str | None:
return self._name

@property
def process_group(self):
def process_group(self) -> ProcessGroup:
return self._pg

@property
def world_size(self):
def world_size(self) -> int:
return self._world_size

@property
def backend(self):
def backend(self) -> str:
return self._pg.name()

@property
def id(self):
def id(self) -> int:
return self._id

def is_member(self):
def is_member(self) -> bool:
if self.rank < 0:
return False
if self.nranks < 2:
return False
return True

def get_group_rank(self, rank):
def get_group_rank(self, rank: int) -> int | Literal[-1]:
if self.is_member():
return self.ranks.index(rank)
else:
return -1

def __repr__(self):
def __repr__(self) -> str:
debug_str = (
f"rank: {self.rank}, nranks: {self.nranks}, id: {self.id}, ranks: "
)
@@ -126,7 +140,7 @@ def _get_or_throw_group_rank(global_rank, group):
return group_rank


def is_initialized():
def is_initialized() -> bool:
"""

Check whether the distributed environment has been initialized
@@ -154,7 +168,7 @@ def is_initialized():
return _GroupManager.global_group_id in _GroupManager.group_map_by_id


def destroy_process_group(group=None):
def destroy_process_group(group: Group | None = None) -> None:
"""
Destroy a given group for communication

@@ -196,7 +210,7 @@ def destroy_process_group(group=None):
del _GroupManager.group_map_by_id[group.id]


def get_group(id=0):
def get_group(id: int = 0) -> Group:
"""

Get group instance by group id.
@@ -255,7 +269,9 @@ def _sync_comm_stream(tensor, ring_id=0):
)


def wait(tensor, group=None, use_calc_stream=True):
def wait(
tensor: Tensor, group: Group | None = None, use_calc_stream: bool = True
) -> None:
"""

wait to sync stream for group.
@@ -291,7 +307,7 @@ def wait(tensor, group=None, use_calc_stream=True):
_sync_comm_stream(tensor, ring_id)


def barrier(group=None):
def barrier(group: Group | None = None) -> None:
"""

Barrier among all participators in the group.
@@ -347,7 +363,7 @@ def barrier(group=None):
)


def get_backend(group=None):
def get_backend(group: Group | None = None) -> str:
"""
Get the backend of given group.

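For context, a minimal sketch of how the newly annotated APIs in group.py are exercised. It assumes a multi-process run launched via `paddle.distributed.launch` with a working NCCL backend; the tensor shape and group id are illustrative only and are not part of this PR.

```python
# Minimal usage sketch (assumes `python -m paddle.distributed.launch ...`);
# the tensor shape and group id below are illustrative, not taken from the diff.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()

if dist.is_initialized():                         # -> bool, per the new annotation
    group = dist.get_group(0)                     # -> Group (the global group)
    print(group.rank, group.nranks)               # int, int
    print(group.get_group_rank(dist.get_rank()))  # rank within the group, or -1

    x = paddle.ones([2, 2])
    dist.barrier(group=group)                     # group: Group | None = None
    dist.wait(x, group=group, use_calc_stream=True)
```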
158 changes: 84 additions & 74 deletions python/paddle/incubate/nn/functional/block_multihead_attention.py
Review comment (Contributor):
The example code in this file needs Flash Attention, so skip it in the example by adding >>> # doctest: +SKIP('Need compile flash attention')

If you are not sure how to add it, you can search the codebase for existing usages ~
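For anyone unsure of the placement, the directive goes on the first line of the doctest block inside the docstring's Examples section. The sketch below uses a hypothetical op purely to show where the line sits:

```python
def my_fused_op(x):  # hypothetical function, only to illustrate the directive placement
    """
    Examples:
        .. code-block:: python

            >>> # doctest: +SKIP('Need compile flash attention')
            >>> # doctest: +REQUIRES(env:GPU)
            >>> import paddle
            >>> # ... example body that requires flash attention ...
    """
```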

@@ -12,47 +12,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Literal, TypeAlias

from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode

if TYPE_CHECKING:
from paddle import Tensor

_QuantRoundType: TypeAlias = Literal[0, 1]


def block_multihead_attention(
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cum_offsets,
cu_seqlens_q,
cu_seqlens_k,
block_tables,
pre_key_cache=None,
pre_value_cache=None,
cache_k_quant_scales=None,
cache_v_quant_scales=None,
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
qkv_out_scale=None,
qkv_bias=None,
out_shift=None,
out_smooth=None,
max_enc_len_this_time=None,
max_dec_len_this_time=None,
rope_emb=None,
mask=None,
tgt_mask=None,
max_seq_len=-1,
block_size=64,
use_neox_style=False,
use_dynamic_cachekv_quant=False,
quant_round_type=1,
quant_max_bound=127.0,
quant_min_bound=-127.0,
out_scale=-1,
compute_dtype="default",
):
qkv: Tensor,
key_cache: Tensor,
value_cache: Tensor,
seq_lens_encoder: Tensor,
seq_lens_decoder: Tensor,
seq_lens_this_time: Tensor,
padding_offsets: Tensor,
cum_offsets: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
block_tables: Tensor,
pre_key_cache: Tensor | None = None,
pre_value_cache: Tensor | None = None,
cache_k_quant_scales: Tensor | None = None,
cache_v_quant_scales: Tensor | None = None,
cache_k_dequant_scales: Tensor | None = None,
cache_v_dequant_scales: Tensor | None = None,
qkv_out_scale: Tensor | None = None,
qkv_bias: Tensor | None = None,
out_shift: Tensor | None = None,
out_smooth: Tensor | None = None,
max_enc_len_this_time: Tensor | None = None,
max_dec_len_this_time: Tensor | None = None,
rope_emb: Tensor | None = None,
mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
max_seq_len: int = -1,
block_size: int = 64,
use_neox_style: bool = False,
use_dynamic_cachekv_quant: bool = False,
quant_round_type: _QuantRoundType = 1,
quant_max_bound: float = 127.0,
quant_min_bound: float = -127.0,
out_scale: float = -1,
compute_dtype: str = "default",
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""
Block Multi-head attention for text summarization.

@@ -99,6 +108,7 @@ def block_multihead_attention(
Examples:
.. code-block:: python

>>> # doctest: +SKIP('Need compile flash attention')
>>> # doctest: +REQUIRES(env:GPU)
>>> import numpy as np
>>> import paddle
@@ -392,44 +402,44 @@ def block_multihead_attention(


def block_multihead_attention_xpu(
qkv,
key_cache,
value_cache,
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
padding_offsets,
cum_offsets,
cu_seqlens_q,
cu_seqlens_k,
block_tables,
cache_k_per_batch_maxs,
cache_v_per_batch_maxs,
pre_key_cache=None,
pre_value_cache=None,
cache_k_quant_scales=None,
cache_v_quant_scales=None,
cache_k_dequant_scales=None,
cache_v_dequant_scales=None,
qkv_out_scale=None,
qkv_bias=None,
out_shift=None,
out_smooth=None,
max_enc_len_this_time=None,
max_dec_len_this_time=None,
rope_emb=None,
mask=None,
tgt_mask=None,
max_seq_len=-1,
block_size=64,
use_neox_style=False,
use_dynamic_cachekv_quant=False,
quant_round_type=1,
quant_max_bound=127.0,
quant_min_bound=-127.0,
out_scale=-1,
compute_dtype="default",
):
qkv: Tensor,
key_cache: Tensor,
value_cache: Tensor,
seq_lens_encoder: Tensor,
seq_lens_decoder: Tensor,
seq_lens_this_time: Tensor,
padding_offsets: Tensor,
cum_offsets: Tensor,
cu_seqlens_q: Tensor,
cu_seqlens_k: Tensor,
block_tables: Tensor,
cache_k_per_batch_maxs: Tensor,
cache_v_per_batch_maxs: Tensor,
pre_key_cache: Tensor | None = None,
pre_value_cache: Tensor | None = None,
cache_k_quant_scales: Tensor | None = None,
cache_v_quant_scales: Tensor | None = None,
cache_k_dequant_scales: Tensor | None = None,
cache_v_dequant_scales: Tensor | None = None,
qkv_out_scale: Tensor | None = None,
qkv_bias: Tensor | None = None,
out_shift: Tensor | None = None,
out_smooth: Tensor | None = None,
max_enc_len_this_time: Tensor | None = None,
max_dec_len_this_time: Tensor | None = None,
rope_emb: Tensor | None = None,
mask: Tensor | None = None,
tgt_mask: Tensor | None = None,
max_seq_len: int = -1,
block_size: int = 64,
use_neox_style: bool = False,
use_dynamic_cachekv_quant: bool = False,
quant_round_type: _QuantRoundType = 1,
quant_max_bound: float = 127.0,
quant_min_bound: float = -127.0,
out_scale: float = -1,
compute_dtype: str = "default",
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
if in_dynamic_mode():
return _C_ops.block_multihead_attention_xpu(
qkv,
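The typing conventions this PR applies to both files can be condensed as follows. Only the imports and the `_QuantRoundType` alias mirror the diff; the function and parameter names are hypothetical, included just to show the annotation style.

```python
from __future__ import annotations  # postponed evaluation, so `Tensor | None` parses on older runtimes

from typing import TYPE_CHECKING, Literal, TypeAlias  # mirrors the imports added in the PR

if TYPE_CHECKING:
    # Imported for static analysis only; nothing extra is imported at runtime.
    from paddle import Tensor

# Type checkers reject any value other than 0 or 1 for parameters using this alias.
_QuantRoundType: TypeAlias = Literal[0, 1]


def fused_op(  # hypothetical signature illustrating the annotation style
    x: Tensor,
    scale: Tensor | None = None,
    quant_round_type: _QuantRoundType = 1,
) -> tuple[Tensor, Tensor]:
    ...
```

Because annotations are postponed, the `Tensor | None` and `tuple[...]` syntax stays a no-op at runtime while still surfacing full signatures to IDEs and type checkers.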