Initial MHA native implementation (#51)
* Initial MHA native implementation
* Test for comparison between native and torch MHA update
* Fixed norm to match torch's implementation
* Removed the inversion of the mask
* Formatting fixes
* Added docstrings
* Renaming MultiheadAttention to MultiheadSelfAttention
* Renaming mha.py to mhsa.py

Co-authored-by: Benedikt Hilmes <[email protected]>
1 parent 56bf9fa · commit ef22941 · 2 changed files with 148 additions and 0 deletions.
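For context on the "Fixed norm to match torch's implementation" item above: torch.nn.MultiheadAttention scales the attention logits by the square root of the per-head dimension (input_dim / num_att_heads), not of the full model dimension. A minimal standalone sketch of that scaled dot-product, not taken from the commit and with illustrative tensor names and sizes:

import math

import torch

B, H, T, d_head = 4, 8, 15, 4  # batch, heads, time, per-head dim (illustrative only)

q = torch.randn(B, H, T, d_head)
k = torch.randn(B, H, T, d_head)
v = torch.randn(B, H, T, d_head)

# Scale by sqrt(d_head) = sqrt(input_dim / num_att_heads), matching torch's behaviour.
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_head)  # [B, H, T, T]
alpha = torch.softmax(scores, dim=-1)  # attention weights
out = torch.matmul(alpha, v)  # [B, H, T, d_head]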
i6_models/parts/mhsa.py
@@ -0,0 +1,91 @@
__all__ = ["MultiheadSelfAttentionV1", "MultiheadSelfAttentionV1Config"]

import math
from dataclasses import dataclass
import torch

from i6_models.config import ModelConfiguration
from i6_models.util import compat


@dataclass
class MultiheadSelfAttentionV1Config(ModelConfiguration):
    """
    Attributes:
        input_dim: input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`
        num_att_heads: number of attention heads
        att_weights_dropout: attention weights dropout probability
        dropout: dropout probability (not currently applied inside this module)
    """

    input_dim: int
    num_att_heads: int
    att_weights_dropout: float
    dropout: float

    def __post_init__(self) -> None:
        super().__post_init__()
        assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads"


class MultiheadSelfAttentionV1(torch.nn.Module):
    """
    Native Multihead Self Attention implementation based on 'Attention Is All You Need'
    """

    def __init__(self, cfg: MultiheadSelfAttentionV1Config):
        super().__init__()
        self.cfg = cfg
        self.num_att_heads = cfg.num_att_heads
        self.input_dim = cfg.input_dim
        self.dim_heads = self.input_dim // self.num_att_heads
        self.layernorm = torch.nn.LayerNorm(cfg.input_dim)

        self.out_proj = torch.nn.Linear(in_features=cfg.input_dim, out_features=cfg.input_dim, bias=True)
        self.in_proj = torch.nn.Linear(in_features=cfg.input_dim, out_features=3 * cfg.input_dim, bias=True)

        self.norm = math.sqrt(float(self.input_dim / self.num_att_heads))
        self.softmax = torch.nn.Softmax(-1)
        self.dropout = torch.nn.Dropout(cfg.att_weights_dropout)

    def forward(self, qkv: torch.Tensor, key_padding_mask: torch.Tensor):
        """
        Computes the forward pass of the MultiheadSelfAttentionV1 module.

        Args:
            qkv (torch.Tensor): The input tensor of shape (B, T, F).
            key_padding_mask (torch.Tensor): The key padding mask tensor of shape (batch_dim, num_tokens), True at padded positions.
        """

        batch_dim, num_tokens, embed_dim = qkv.shape
        x = self.in_proj(qkv)

        hidden_dim = qkv.size(-1)
        query, key, value = x.unflatten(-1, (3, hidden_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()

        query = query.view(batch_dim, -1, self.num_att_heads, self.dim_heads)  # [B, T, D//H, D']
        key = key.view(batch_dim, -1, self.num_att_heads, self.dim_heads)  # [B, T, D//H, D']
        value = value.view(batch_dim, -1, self.num_att_heads, self.dim_heads)  # [B, T, D//H, D']

        query = torch.transpose(query, 1, 2)  # [B, D//H, T, D']
        key = torch.transpose(key, 1, 2)  # [B, D//H, T, D']
        value = torch.transpose(value, 1, 2)  # [B, D//H, T, D']

        key = torch.transpose(key, -2, -1)  # [B, D//H, D', T]

        dot = torch.matmul(query, key)  # [B, D//H, T, T]
        dot = dot / self.norm

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.view(batch_dim, 1, 1, key_padding_mask.size(1))
            dot = dot.masked_fill(key_padding_mask, -float("inf"))

        alpha = self.softmax(dot)  # [B, D//H, T, T]
        alpha = self.dropout(alpha)

        att_out = torch.matmul(alpha, value)  # [B, D//H, T, D']
        att_out = torch.transpose(att_out, 1, 2)  # [B, T, D//H, D']
        att_out = att_out.reshape(batch_dim, -1, self.input_dim)  # [B, T, D]
        att_out = self.out_proj(att_out)

        return att_out, alpha
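A minimal usage sketch for the module above (not part of the commit; the sizes mirror the tests below). Following the masking convention visible in forward, True in key_padding_mask marks padded positions that are excluded from attention:

import torch

from i6_models.parts.mhsa import MultiheadSelfAttentionV1, MultiheadSelfAttentionV1Config

cfg = MultiheadSelfAttentionV1Config(input_dim=32, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1)
mhsa = MultiheadSelfAttentionV1(cfg)

x = torch.randn(4, 15, 32)  # [B, T, F]
key_padding_mask = torch.zeros(4, 15, dtype=torch.bool)  # True = padded position
key_padding_mask[:, 12:] = True  # e.g. the last three frames of every sequence are padding

out, att_weights = mhsa(x, key_padding_mask)  # out: [4, 15, 32], att_weights: [4, 8, 15, 15]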
@@ -0,0 +1,57 @@
from __future__ import annotations

import torch

from i6_models.parts.mhsa import MultiheadSelfAttentionV1, MultiheadSelfAttentionV1Config


def test_MultiheadSelfAttentionV1():
    """
    Test the functionality of the MultiheadSelfAttentionV1 module.
    """

    def get_output_shape(input_shape, cfg, key_padding_mask=None, need_weights=True):
        input_tensor = torch.randn(input_shape)
        mhsa = MultiheadSelfAttentionV1(cfg)
        output, weights = mhsa(input_tensor, key_padding_mask)
        return output.shape, weights.shape

    cfg = MultiheadSelfAttentionV1Config(input_dim=32, num_att_heads=8, att_weights_dropout=0.2, dropout=0.3)
    input_shape = [4, 15, 32]  # B,T,F

    key_padding_mask = torch.randint(0, 2, (input_shape[0], input_shape[1])) > 0

    assert get_output_shape(input_shape, cfg, key_padding_mask) == (torch.Size([4, 15, 32]), torch.Size([4, 8, 15, 15]))


def test_ComparisonMHSAV1Torch():
    """
    Compares the output of the MultiheadSelfAttentionV1 module with the output of the torch.nn.MultiheadAttention module.
    """
    cfg = MultiheadSelfAttentionV1Config(input_dim=32, num_att_heads=8, att_weights_dropout=0, dropout=0)
    torch_mhsa = torch.nn.MultiheadAttention(cfg.input_dim, cfg.num_att_heads, dropout=0, batch_first=True)
    torch_mhsa.eval()

    mhsav1 = MultiheadSelfAttentionV1(cfg)
    mhsav1.eval()

    in_proj_weight = torch_mhsa.in_proj_weight
    in_proj_bias = torch_mhsa.in_proj_bias

    out_proj_weight = torch_mhsa.out_proj.weight
    out_proj_bias = torch_mhsa.out_proj.bias

    mhsav1.in_proj.weight = in_proj_weight
    mhsav1.in_proj.bias = in_proj_bias
    mhsav1.out_proj.weight = out_proj_weight
    mhsav1.out_proj.bias = out_proj_bias

    input_shape = [4, 15, 32]  # B,T,F
    input_tensor = torch.randn(input_shape)

    key_padding_mask = torch.randint(0, 2, (input_shape[0], input_shape[1])) > 0

    mhsav1_out, _ = mhsav1(input_tensor, key_padding_mask)
    torch_mhsa_out, _ = torch_mhsa(input_tensor, input_tensor, input_tensor, key_padding_mask)

    assert torch.allclose(mhsav1_out, torch_mhsa_out, atol=1e-08)
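The tests build key_padding_mask randomly; in practice it is usually derived from sequence lengths. A hypothetical helper (not part of the commit) producing a mask with the same convention, True at padded positions:

import torch


def lengths_to_key_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    # Positions at or beyond a sequence's length are padding and must be masked out.
    positions = torch.arange(max_len, device=lengths.device)  # [T]
    return positions[None, :] >= lengths[:, None]  # [B, T], bool


mask = lengths_to_key_padding_mask(torch.tensor([15, 12, 9, 7]), max_len=15)  # shape [4, 15]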