Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use split/squeeze instead of slice for performance #409

Merged
merged 2 commits into from
Jul 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Transformer."""
Expand Down Expand Up @@ -660,9 +661,10 @@ def repeat_kv(self, hidden_states, n_rep):
head_dim)

def split_tensor(self, mixed_x_layer):
query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:2] + (-1, self.hidden_size_per_attention_head))
key_layer = mixed_x_layer[:, :, :, -2, :]
value_layer = mixed_x_layer[:, :, :, -1, :]
query_layer, key_layer, value_layer = torch.split(mixed_x_layer, [self.num_key_value_groups, 1, 1], dim=-2)
query_layer = query_layer.reshape(mixed_x_layer.shape[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
key_layer = torch.squeeze(key_layer, -2)
value_layer = torch.squeeze(value_layer, -2)

return query_layer, key_layer, value_layer

Expand Down