Initial MHA native implementation (#51)
* Initial MHA native implementation

* Test for comparison between native and torch MHA update

* Fixed norm to match torch's implementation

* Removed the inversion of the mask

* Formatting fixes

* Added docstrings

* Renaming MultiheadAttention to MultiheadSelfAttention

* Renaming mha.py to mhsa.py

---------

Co-authored-by: Benedikt Hilmes <[email protected]>
sleepyeldrazi and Atticus1806 authored Jul 1, 2024
1 parent 56bf9fa commit ef22941
Showing 2 changed files with 148 additions and 0 deletions.
91 changes: 91 additions & 0 deletions i6_models/parts/mhsa.py
@@ -0,0 +1,91 @@
__all__ = ["MultiheadSelfAttentionV1", "MultiheadSelfAttentionV1Config"]

import math
from dataclasses import dataclass
from typing import Optional

import torch

from i6_models.config import ModelConfiguration
from i6_models.util import compat


@dataclass
class MultiheadSelfAttentionV1Config(ModelConfiguration):
"""
Attributes:
        input_dim: input dimension, which is also the total dimension of the query/key/value projections; must be divisible by `num_att_heads`
        num_att_heads: number of attention heads
        att_weights_dropout: dropout probability applied to the attention weights
        dropout: dropout probability (not used inside this module)
"""

input_dim: int
num_att_heads: int
att_weights_dropout: float
dropout: float

def __post_init__(self) -> None:
super().__post_init__()
assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads"


class MultiheadSelfAttentionV1(torch.nn.Module):
"""
Native Multihead Self Attention implementation based on 'Attention Is All You Need'
"""

def __init__(self, cfg: MultiheadSelfAttentionV1Config):
super().__init__()
self.cfg = cfg
self.num_att_heads = cfg.num_att_heads
self.input_dim = cfg.input_dim
self.dim_heads = self.input_dim // self.num_att_heads
self.layernorm = torch.nn.LayerNorm(cfg.input_dim)

self.out_proj = torch.nn.Linear(in_features=cfg.input_dim, out_features=cfg.input_dim, bias=True)
self.in_proj = torch.nn.Linear(in_features=cfg.input_dim, out_features=3 * cfg.input_dim, bias=True)

self.norm = math.sqrt(float(self.input_dim / self.num_att_heads))
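        # sqrt(D') scaling of the attention logits, matching the scaling used by torch.nn.MultiheadAttention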
self.softmax = torch.nn.Softmax(-1)
self.dropout = torch.nn.Dropout(cfg.att_weights_dropout)

    def forward(self, qkv: torch.Tensor, key_padding_mask: Optional[torch.Tensor]):
"""
Computes the forward pass of the MultiheadSelfAttentionV1 module.
        Args:
            qkv (torch.Tensor): input tensor of shape [B, T, F]
            key_padding_mask (torch.Tensor): key padding mask of shape [B, T]; True marks padded positions that are excluded from attention

        Returns:
            the attention output of shape [B, T, F] and the attention weights of shape [B, H, T, T]
        """

batch_dim, num_tokens, embed_dim = qkv.shape
x = self.in_proj(qkv)

hidden_dim = qkv.size(-1)
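        # split the packed in_proj output [B, T, 3*F] into query, key and value tensors, each of shape [B, T, F]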
query, key, value = x.unflatten(-1, (3, hidden_dim)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()

        query = query.view(batch_dim, -1, self.num_att_heads, self.dim_heads)  # [B, T, H, D'], H = num_att_heads, D' = F // H
        key = key.view(batch_dim, -1, self.num_att_heads, self.dim_heads)  # [B, T, H, D']
        value = value.view(batch_dim, -1, self.num_att_heads, self.dim_heads)  # [B, T, H, D']

        query = torch.transpose(query, 1, 2)  # [B, H, T, D']
        key = torch.transpose(key, 1, 2)  # [B, H, T, D']
        value = torch.transpose(value, 1, 2)  # [B, H, T, D']

        key = torch.transpose(key, -2, -1)  # [B, H, D', T]

        dot = torch.matmul(query, key)  # [B, H, T, T]
dot = dot / self.norm

if key_padding_mask is not None:
key_padding_mask = key_padding_mask.view(batch_dim, 1, 1, key_padding_mask.size(1))
dot = dot.masked_fill(key_padding_mask, -float("inf"))
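            # padded positions (True in the mask) receive -inf logits and thus zero weight after the softmax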

        alpha = self.softmax(dot)  # [B, H, T, T]
alpha = self.dropout(alpha)

        att_out = torch.matmul(alpha, value)  # [B, H, T, D']
        att_out = torch.transpose(att_out, 1, 2)  # [B, T, H, D']
        att_out = att_out.reshape(batch_dim, -1, self.input_dim)  # [B, T, F]
att_out = self.out_proj(att_out)

return att_out, alpha
57 changes: 57 additions & 0 deletions tests/test_mha.py
@@ -0,0 +1,57 @@
from __future__ import annotations

import torch

from i6_models.parts.mhsa import MultiheadSelfAttentionV1, MultiheadSelfAttentionV1Config


def test_MultiheadSelfAttentionV1():
"""
Test the functionality of the MultiheadSelfAttentionV1 module.
"""

def get_output_shape(input_shape, cfg, key_padding_mask=None, need_weights=True):
input_tensor = torch.randn(input_shape)
mhsa = MultiheadSelfAttentionV1(cfg)
output, weights = mhsa(input_tensor, key_padding_mask)
return output.shape, weights.shape

cfg = MultiheadSelfAttentionV1Config(input_dim=32, num_att_heads=8, att_weights_dropout=0.2, dropout=0.3)
input_shape = [4, 15, 32] # B,T,F

key_padding_mask = torch.randint(0, 2, (input_shape[0], input_shape[1])) > 0
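    # random boolean mask; True marks positions that are treated as padding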

assert get_output_shape(input_shape, cfg, key_padding_mask) == (torch.Size([4, 15, 32]), torch.Size([4, 8, 15, 15]))


def test_ComparisonMHSAV1Torch():
"""
Compares the output of the MultiheadSelfAttentionV1 module with the output of the torch.nn.MultiheadAttention module.
"""
cfg = MultiheadSelfAttentionV1Config(input_dim=32, num_att_heads=8, att_weights_dropout=0, dropout=0)
torch_mhsa = torch.nn.MultiheadAttention(cfg.input_dim, cfg.num_att_heads, dropout=0, batch_first=True)
torch_mhsa.eval()

mhsav1 = MultiheadSelfAttentionV1(cfg)
mhsav1.eval()
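    # copy torch's packed input projection and its output projection into the native module so both compute identical outputs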

in_proj_weight = torch_mhsa.in_proj_weight
in_proj_bias = torch_mhsa.in_proj_bias

out_proj_weight = torch_mhsa.out_proj.weight
out_proj_bias = torch_mhsa.out_proj.bias

mhsav1.in_proj.weight = in_proj_weight
mhsav1.in_proj.bias = in_proj_bias
mhsav1.out_proj.weight = out_proj_weight
mhsav1.out_proj.bias = out_proj_bias

input_shape = [4, 15, 32] # B,T,F
input_tensor = torch.randn(input_shape)

key_padding_mask = torch.randint(0, 2, (input_shape[0], input_shape[1])) > 0

mhsav1_out, _ = mhsav1(input_tensor, key_padding_mask)
torch_mhsa_out, _ = torch_mhsa(input_tensor, input_tensor, input_tensor, key_padding_mask)

assert torch.allclose(mhsav1_out, torch_mhsa_out, atol=1e-08)
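
For orientation, a minimal usage sketch of the new module outside of the tests; the batch size, sequence lengths and dropout values below are illustrative assumptions rather than values taken from this commit:

import torch

from i6_models.parts.mhsa import MultiheadSelfAttentionV1, MultiheadSelfAttentionV1Config

cfg = MultiheadSelfAttentionV1Config(input_dim=32, num_att_heads=8, att_weights_dropout=0.1, dropout=0.1)
mhsa = MultiheadSelfAttentionV1(cfg)

features = torch.randn(3, 20, 32)  # [B, T, F]
seq_lens = torch.tensor([20, 14, 9])  # assumed sequence lengths per batch entry
# True marks padded positions, following the mask convention of the module and of torch.nn.MultiheadAttention
key_padding_mask = torch.arange(20)[None, :] >= seq_lens[:, None]  # [B, T]

out, att_weights = mhsa(features, key_padding_mask)  # out: [B, T, F], att_weights: [B, H, T, T]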
