From 13f2673caab46152ee7ba8051b8b0a9602270775 Mon Sep 17 00:00:00 2001 From: Polisetty V R K Jyothendra Varma Date: Fri, 19 Jul 2024 04:02:32 +0300 Subject: [PATCH] improve repeat_kv GQA perf (#419) --- megatron/model/transformer.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index be8ae6ef4b..e79abea3cf 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -654,11 +654,16 @@ def repeat_kv(self, hidden_states, n_rep): slen, batch, num_key_value_heads_per_partition, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, :, None, :].expand( - slen, batch, num_key_value_heads_per_partition, n_rep, head_dim) - return hidden_states.reshape(slen, batch, - num_key_value_heads_per_partition * n_rep, - head_dim) + elif num_key_value_heads_per_partition == 1: + # If no of KV heads is 1 then just perform expand operation + # instead of unsqueeze, expand and reshape to match query states. + return hidden_states.expand(slen, batch, n_rep, head_dim) + else: + hidden_states = hidden_states[:, :, :, None, :].expand( + slen, batch, num_key_value_heads_per_partition, n_rep, head_dim) + return hidden_states.reshape(slen, batch, + num_key_value_heads_per_partition * n_rep, + head_dim) def split_tensor(self, mixed_x_layer): query_layer, key_layer, value_layer = torch.split(mixed_x_layer, [self.num_key_value_groups, 1, 1], dim=-2)