improve repeat_kv GQA perf (#419)
polisettyvarma authored Jul 19, 2024
1 parent 7d23e33 commit 13f2673
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions megatron/model/transformer.py
@@ -654,11 +654,16 @@ def repeat_kv(self, hidden_states, n_rep):
         slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape
         if n_rep == 1:
             return hidden_states
-        hidden_states = hidden_states[:, :, :, None, :].expand(
-            slen, batch, num_key_value_heads_per_partition, n_rep, head_dim)
-        return hidden_states.reshape(slen, batch,
-                                     num_key_value_heads_per_partition * n_rep,
-                                     head_dim)
+        elif num_key_value_heads_per_partition == 1:
+            # If no of KV heads is 1 then just perform expand operation
+            # instead of unsqueeze, expand and reshape to match query states.
+            return hidden_states.expand(slen, batch, n_rep, head_dim)
+        else:
+            hidden_states = hidden_states[:, :, :, None, :].expand(
+                slen, batch, num_key_value_heads_per_partition, n_rep, head_dim)
+            return hidden_states.reshape(slen, batch,
+                                         num_key_value_heads_per_partition * n_rep,
+                                         head_dim)

     def split_tensor(self, mixed_x_layer):
         query_layer, key_layer, value_layer = torch.split(mixed_x_layer, [self.num_key_value_groups, 1, 1], dim=-2)

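For reference, a minimal standalone sketch (not part of the commit) of what the new fast path does: when there is only one KV head per tensor-parallel partition, expanding the size-1 head dimension directly gives the same values as the original unsqueeze/expand/reshape sequence, using a single view operation instead of three. Shapes and variable names below are illustrative.

    import torch

    # Illustrative shapes; names mirror those used in repeat_kv.
    slen, batch, n_rep, head_dim = 4, 2, 8, 16
    hidden_states = torch.randn(slen, batch, 1, head_dim)  # one KV head per partition

    # Original path: unsqueeze, expand, then reshape to merge the two middle dims.
    reference = hidden_states[:, :, :, None, :].expand(
        slen, batch, 1, n_rep, head_dim).reshape(slen, batch, 1 * n_rep, head_dim)

    # Fast path from this commit: a single expand of the size-1 head dimension.
    fast = hidden_states.expand(slen, batch, n_rep, head_dim)

    assert torch.equal(reference, fast)
    # expand() returns a view over the original storage, so nothing is copied.
    assert fast.data_ptr() == hidden_states.data_ptr()

When num_key_value_heads_per_partition is greater than 1, the reshape generally cannot be expressed as a view of the expanded tensor (the repeated dimension has stride 0) and has to materialize a copy, which is presumably why the original three-step path is kept in the else branch.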