diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e75f13a24f..be8ae6ef4b 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -1,3 +1,4 @@
+# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 """Transformer."""
@@ -660,9 +661,10 @@ def repeat_kv(self, hidden_states, n_rep):
                                      head_dim)
 
     def split_tensor(self, mixed_x_layer):
-        query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:2] + (-1, self.hidden_size_per_attention_head))
-        key_layer = mixed_x_layer[:, :, :, -2, :]
-        value_layer = mixed_x_layer[:, :, :, -1, :]
+        query_layer, key_layer, value_layer = torch.split(mixed_x_layer, [self.num_key_value_groups, 1, 1], dim=-2)
+        query_layer = query_layer.reshape(mixed_x_layer.shape[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
+        key_layer = torch.squeeze(key_layer, -2)
+        value_layer = torch.squeeze(value_layer, -2)
         return query_layer, key_layer, value_layer