-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix potential random layout inconsistency issues in sparse attention modules #534
Changes from 1 commit
fb28742
2509bad
2a3e0a8
e087d99
da2dd1e
fbdf2bf
8bed417
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
import torch.nn as nn | ||
from torch.nn.functional import * | ||
import torch | ||
from torch import distributed as dist | ||
from collections import namedtuple | ||
from deepspeed.ops.sparse_attention import MatMul, Softmax, SparsityConfig | ||
import sys | ||
|
@@ -22,7 +23,8 @@ def __init__( | |
# SparsityConfig parameters needs to be set accordingly | ||
sparsity_config=SparsityConfig(num_heads=4), | ||
key_padding_mask_mode='add', | ||
attn_mask_mode='mul'): | ||
attn_mask_mode='mul', | ||
max_seq_length=2048): | ||
"""Initialize the sparse self attention layer. | ||
Arguments: | ||
sparsity_config: optional: this parameter determins sparsity pattern configuration; it is based on SparsityConfig class. | ||
|
@@ -34,17 +36,36 @@ def __init__( | |
# sparsity information | ||
self.sparsity_config = sparsity_config | ||
|
||
# initialize sparse layout and register as buffer | ||
master_layout = self.sparsity_config.make_layout(max_seq_length) | ||
self.register_buffer("master_layout", master_layout) | ||
self._need_layout_synchronization = True | ||
|
||
# mask modes | ||
self.key_padding_mask_mode = key_padding_mask_mode | ||
self.attn_mask_mode = attn_mask_mode | ||
|
||
ops = dict() | ||
|
||
def get_layout(self, L): | ||
# if layout is never synchronized across GPUs, broadcast the layout from global rank 0 | ||
if self._need_layout_synchronization and dist.is_initialized(): | ||
dist.broadcast(self.master_layout, src=0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might break with model parallelism (e.g., megatron-style or pipeline parallelism). However, it might be tricky to get the correct process group and rank inside the op since we can't easily communicate with the deepspeed engine to get this info here. /cc @ShadenSmith, @samyam There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good point @jeffra . I think we want to only broadcast along the data parallel group, similar to our weight initialization? But getting the group is tricky as you pointed out. We could add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep that makes sense. The training engine hasn't been created yet in this timeline, so that's a bit tricky. However, for now let's just the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we have multiple data_parallel_groups (i.e. in model parallel scenario), does that mean we would also require passing in the source rank to broadcast from within that process group? Do you think we would also need an optional argument for broadcast_src_rank in the constructor? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, we could also add the broadcast_src_rank parameter. This just means the caller has to do this translation instead of us, which sounds fine. |
||
self._need_layout_synchronization = False | ||
|
||
if (L % self.sparsity_config.block != 0): | ||
raise ValueError( | ||
f'Sequence Length, {L}, needs to be dividable by Block size {self.sparsity_config.block}!' | ||
) | ||
|
||
num_blocks = L // self.sparsity_config.block | ||
return self.master_layout[..., :num_blocks, :num_blocks].cpu() # layout needs to be a CPU tensor | ||
|
||
# add to cache | ||
def get_ops(self, H, L): | ||
import sys | ||
if L not in SparseSelfAttention.ops: | ||
sparsity_layout = self.sparsity_config.make_layout(L) | ||
sparsity_layout = self.get_layout(L) | ||
sparse_dot_sdd_nt = MatMul(sparsity_layout, | ||
self.sparsity_config.block, | ||
'sdd', | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add docstring for the new parameter as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, just added.