From 309d3f0c5e8525ad59047995610eb7374b6f246e Mon Sep 17 00:00:00 2001
From: Jinghan Yao
Date: Sat, 3 Aug 2024 00:09:20 -0400
Subject: [PATCH] pass batch_dim_idx to deepspeed sequence parallel distributed
 attention for supporting batch size larger than 1

---
 megatron/model/transformer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 7467190582..036c11566a 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -817,12 +817,14 @@ def forward(self, hidden_states, attention_mask,
             # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)
 
         if self.enable_ds_sequence_parallel:
+            batch_dim_idx = 1
             if self.use_flash_attn:
                 if not self.use_flash_attn_triton:
                     query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                         for x in (query_layer, key_layer, value_layer)]
+                    batch_dim_idx = 0
 
-                context_layer = self.dist_attn(query_layer, key_layer, value_layer)
+                context_layer = self.dist_attn(query_layer, key_layer, value_layer, batch_dim_idx)
 
                 if not self.use_flash_attn_triton:
                     context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
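
Note (not part of the patch): a minimal, self-contained sketch of why batch_dim_idx has to follow
the tensor layout. Megatron keeps attention tensors as [sq, b, np, hn], so the batch dimension sits
at index 1; the non-Triton FlashAttention path rearranges them to [b, sq, np, hn], moving batch to
index 0. The sequence-parallel distributed attention module (self.dist_attn) uses this index to
locate the batch dimension when it redistributes sequence and head dimensions across ranks. The
shapes, variable names, and flag values below are assumptions for illustration only.

    import torch
    from einops import rearrange

    seq_len, batch, heads, head_dim = 16, 4, 8, 64   # toy sizes, assumed for illustration

    # Megatron layout: [sq, b, np, hn] -> batch dimension is index 1.
    q = torch.randn(seq_len, batch, heads, head_dim)
    k = torch.randn(seq_len, batch, heads, head_dim)
    v = torch.randn(seq_len, batch, heads, head_dim)
    batch_dim_idx = 1

    use_flash_attn, use_flash_attn_triton = True, False   # assumed flag values
    if use_flash_attn and not use_flash_attn_triton:
        # Non-Triton FlashAttention expects [b, sq, np, hn]; after this rearrange the
        # batch dimension moves to index 0, so batch_dim_idx must move with it.
        q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() for x in (q, k, v)]
        batch_dim_idx = 0

    # The patched call site then forwards the layout hint alongside q, k, v:
    #     context_layer = self.dist_attn(q, k, v, batch_dim_idx)
    assert q.shape[batch_dim_idx] == batch   # batch_dim_idx always points at the batch dim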