Add MHSA module, Conformer block and encoder with relative PE #55

Merged
merged 29 commits on Sep 12, 2024
Changes from 2 commits
Commits (29)
751da88
initial commit
kuacakuaca Jun 24, 2024
873042a
fix
kuacakuaca Jun 25, 2024
e066c60
add flag for broadcasting dropout
kuacakuaca Jul 1, 2024
61b87f5
fix
kuacakuaca Jul 2, 2024
591c777
use kwargs in ffn test
kuacakuaca Jul 4, 2024
56a1d48
black
kuacakuaca Jul 4, 2024
69f733c
fix
kuacakuaca Jul 4, 2024
4df2739
change from torch.matmul to torch.einsum
kuacakuaca Jul 18, 2024
8f9d2c2
move to V2
kuacakuaca Aug 14, 2024
647ee05
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Aug 14, 2024
31fec38
add test for ConformerMHSARelPosV1
kuacakuaca Aug 14, 2024
2b19cd3
add validity check, fix param. init.
kuacakuaca Aug 14, 2024
2bf2c89
black
kuacakuaca Aug 14, 2024
600752e
make dropout model modules
kuacakuaca Aug 21, 2024
7eb61a1
update docstring
kuacakuaca Aug 28, 2024
b8db085
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Sep 3, 2024
83b23c9
adress feedback
kuacakuaca Sep 4, 2024
c2a301f
update docstring
kuacakuaca Sep 4, 2024
a4929dc
Apply suggestions from code review
kuacakuaca Sep 5, 2024
338ff2c
Update i6_models/parts/conformer/mhsa_rel_pos.py
kuacakuaca Sep 5, 2024
c0c706e
black
kuacakuaca Sep 5, 2024
7f4decd
fix & add test case
kuacakuaca Sep 5, 2024
5bf24d2
try fixing
kuacakuaca Sep 5, 2024
0b9fb9d
remove default, update docstring and test case
kuacakuaca Sep 6, 2024
f9cdf6e
fix espnet version
kuacakuaca Sep 6, 2024
1137f24
Update requirements_dev.txt
kuacakuaca Sep 9, 2024
33aa1f1
remove typeguard from requirements_dev
kuacakuaca Sep 9, 2024
88f3702
expand docstring
kuacakuaca Sep 12, 2024
fa3c4ad
make it consistent
kuacakuaca Sep 12, 2024
1 change: 1 addition & 0 deletions i6_models/assemblies/conformer/__init__.py
@@ -1,2 +1,3 @@
from .conformer_v1 import *
from .conformer_v2 import *
from .conformer_rel_pos_v1 import *
125 changes: 125 additions & 0 deletions i6_models/assemblies/conformer/conformer_rel_pos_v1.py
@@ -0,0 +1,125 @@
from __future__ import annotations

__all__ = [
"ConformerRelPosBlockV1Config",
"ConformerRelPosEncoderV1Config",
"ConformerRelPosBlockV1",
"ConformerRelPosEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass, field
from typing import List

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
ConformerConvolutionV1,
ConformerConvolutionV1Config,
ConformerMHSARelPosV1,
ConformerMHSARelPosV1Config,
ConformerPositionwiseFeedForwardV1,
ConformerPositionwiseFeedForwardV1Config,
)
from i6_models.assemblies.conformer import ConformerEncoderV2


@dataclass
class ConformerRelPosBlockV1Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for ConformerPositionwiseFeedForwardV1
mhsa_cfg: Configuration for ConformerMHSARelPosV1
conv_cfg: Configuration for ConformerConvolutionV1
modules: List of modules to use for ConformerRelPosBlockV1,
"ff" for feed forward module, "mhsa" for multi-head self attention module, "conv" for conv module
scales: List of scales to apply to the module outputs before the residual connection
"""

# nested configurations
ff_cfg: ConformerPositionwiseFeedForwardV1Config
mhsa_cfg: ConformerMHSARelPosV1Config
conv_cfg: ConformerConvolutionV1Config
modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])

def __post_init__(self):
super().__post_init__()
assert len(self.modules) == len(self.scales), "modules and scales must have same length"
for module_name in self.modules:
assert module_name in ["ff", "mhsa", "conv"], "module not supported"


class ConformerRelPosBlockV1(nn.Module):
"""
Conformer block module, modifications compared to ConformerBlockV1:
- uses ConformerMHSARelPosV1 as the MHSA module
- enables constructing the block from a self-defined module list, as in ConformerBlockV2
"""

def __init__(self, cfg: ConformerRelPosBlockV1Config):
"""
:param cfg: conformer block configuration with subunits for the different conformer parts
"""
super().__init__()

modules = []
for module_name in cfg.modules:
if module_name == "ff":
modules.append(ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg))
elif module_name == "mhsa":
modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg))
elif module_name == "conv":
modules.append(ConformerConvolutionV1(model_cfg=cfg.conv_cfg))
else:
raise NotImplementedError

self.module_list = nn.ModuleList(modules)
self.scales = cfg.scales
self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
:param x: input tensor of shape [B, T, F]
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
:return: torch.Tensor of shape [B, T, F]
"""
for scale, module in zip(self.scales, self.module_list):
if isinstance(module, ConformerMHSARelPosV1):
x = scale * module(x, sequence_mask) + x
else:
x = scale * module(x) + x
x = self.final_layer_norm(x) # [B, T, F]
return x


@dataclass
class ConformerRelPosEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of conformer layers in the conformer encoder
frontend: A pair of ConformerFrontend and corresponding config
block_cfg: Configuration for ConformerRelPosBlockV1
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: ConformerRelPosBlockV1Config


class ConformerRelPosEncoderV1(ConformerEncoderV2):
"""
Modifications compared to ConformerEncoderV2:
- uses multi-headed self-attention with Shaw's relative positional encoding
"""

def __init__(self, cfg: ConformerRelPosEncoderV1Config):
"""
:param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
"""
super().__init__(cfg)

self.frontend = cfg.frontend()
self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
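For readers skimming the diff, here is a minimal standalone sketch of the scaled-residual pattern that ConformerRelPosBlockV1.forward implements (scale * module(x) + x per entry of modules/scales, followed by a final layer norm). The toy nn.Linear modules, dimensions, and scale values below are illustrative placeholders, not the actual feed-forward/MHSA/convolution parts from this PR:

import torch
from torch import nn

feature_dim = 8
# toy stand-ins for the "ff", "mhsa", "conv", "ff" parts, only to show the residual wiring
toy_modules = nn.ModuleList([nn.Linear(feature_dim, feature_dim) for _ in range(4)])
scales = [0.5, 1.0, 1.0, 0.5]  # Macaron-style: half-scaled feed-forward modules at both ends
final_layer_norm = nn.LayerNorm(feature_dim)

x = torch.randn(2, 5, feature_dim)  # [B, T, F]
for scale, module in zip(scales, toy_modules):
    x = scale * module(x) + x  # scaled module output plus residual connection
x = final_layer_norm(x)  # [B, T, F]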
1 change: 1 addition & 0 deletions i6_models/parts/conformer/__init__.py
@@ -1,4 +1,5 @@
from .convolution import *
from .feedforward import *
from .mhsa import *
from .mhsa_rel_pos import *
from .norm import *
170 changes: 170 additions & 0 deletions i6_models/parts/conformer/mhsa_rel_pos.py
@@ -0,0 +1,170 @@
from __future__ import annotations

__all__ = ["ConformerMHSARelPosV1", "ConformerMHSARelPosV1Config"]

import math
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F

from i6_models.config import ModelConfiguration
from i6_models.util import compat


@dataclass
class ConformerMHSARelPosV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads`
num_att_heads: number of attention heads
rel_pos_clip: maximal relative position for the embedding; larger distances are clipped to this value
att_weights_dropout: attention weights dropout
dropout: multi-headed self attention output dropout
"""

input_dim: int
num_att_heads: int
rel_pos_clip: int
att_weights_dropout: float
dropout: float

def __post_init__(self) -> None:
super().__post_init__()
assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads"


class ConformerMHSARelPosV1(nn.Module):
"""
Conformer multi-headed self-attention module with relative positional encoding proposed by Shaw et al. (cf. https://arxiv.org/abs/1803.02155)
"""

def __init__(self, cfg: ConformerMHSARelPosV1Config):

super().__init__()

self.layernorm = nn.LayerNorm(cfg.input_dim)

self.embed_dim = cfg.input_dim
self.num_heads = cfg.num_att_heads
self.embed_dim_per_head = self.embed_dim // self.num_heads

self.rel_pos_clip = cfg.rel_pos_clip

self.att_weights_dropout = cfg.att_weights_dropout

assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"

# projection matrices
self.q_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim)))
self.k_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim)))
self.v_proj_weight = nn.parameter.Parameter(torch.empty((self.embed_dim, self.embed_dim)))

self.in_proj_bias = nn.parameter.Parameter(torch.empty(3 * self.embed_dim))

self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

if self.rel_pos_clip:
Review comment (Contributor): should we just disallow a value of 0 (and below) and then get rid of all the if self.rel_pos_clip blocks?

self.rel_pos_embeddings = nn.parameter.Parameter(
torch.empty(self.rel_pos_clip * 2 + 1, self.embed_dim // self.num_heads)
)
else:
self.register_parameter("rel_pos_embeddings", None)

self.dropout = cfg.dropout

self._reset_parameters() # initialize parameters

def _reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj_weight)
nn.init.xavier_uniform_(self.k_proj_weight)
nn.init.xavier_uniform_(self.v_proj_weight)

if self.rel_pos_clip:
nn.init.normal_(self.rel_pos_embeddings)

nn.init.constant_(self.in_proj_bias, 0.0)

def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
Apply layer norm and multi-head self attention and dropout

:param input_tensor: Input to the self attention of shape (B, T, F)
:param sequence_mask: bool mask of shape (B, T), True marks positions inside the sequence and False outside; it is inverted and converted to an additive mask (as in torch.nn.MultiheadAttention) that is added to the attention scores to mask out padded key positions
"""
output_tensor = self.layernorm(input_tensor) # [B,T,F]

time_dim_size = output_tensor.shape[1]
batch_dim_size = output_tensor.shape[0]

# attention mask
inv_sequence_mask = compat.logical_not(sequence_mask) # [B, T]
mask = (
torch.zeros_like(inv_sequence_mask, dtype=input_tensor.dtype)
.masked_fill(inv_sequence_mask, float("-inf"))
.view(batch_dim_size, 1, 1, time_dim_size)
.expand(-1, self.num_heads, -1, -1)
) # [B, #heads, 1, T']

# query, key and value sequences
bias_k, bias_q, bias_v = self.in_proj_bias.chunk(3)

query_seq = F.linear(output_tensor, self.q_proj_weight, bias_q) # [B, T, #heads * F']
key_seq = F.linear(output_tensor, self.k_proj_weight, bias_k)
value_seq = F.linear(output_tensor, self.v_proj_weight, bias_v)

q1 = query_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head).transpose(
1, 2
) # [B, #heads, T, F']
k_t = key_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head).permute(
0, 2, 3, 1
) # [B, #heads, F', T']
# attention between query and key sequences
attn1 = torch.matmul(q1, k_t) # [B, #heads, T, T']

if self.rel_pos_clip:
q2 = (
query_seq.transpose(0, 1)
.contiguous()
.view(time_dim_size, batch_dim_size * self.num_heads, self.embed_dim_per_head)
) # [T, B*#heads, F']

range_vec_q = torch.arange(time_dim_size, device=input_tensor.device)
range_vec_k = torch.arange(time_dim_size, device=input_tensor.device)

distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = torch.clamp(distance_mat, -self.rel_pos_clip, self.rel_pos_clip)

final_mat = distance_mat_clipped + self.rel_pos_clip
# relative positional embeddings
rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', F']

# attention between query sequence and relative positional embeddings
attn2 = torch.matmul(q2, rel_pos_embeddings.transpose(1, 2)).transpose(0, 1) # [B*#heads, T, T']
attn2 = attn2.contiguous().view(
batch_dim_size, self.num_heads, time_dim_size, time_dim_size
) # [B, #heads, T, T']

attn = (attn1 + attn2 + mask) * (math.sqrt(1.0 / float(self.embed_dim_per_head))) # [B, #heads, T, T']
else:
attn = (attn1 + mask) * (math.sqrt(1.0 / float(self.embed_dim_per_head))) # [B, #heads, T, T']

# softmax and dropout
attn_output_weights = F.dropout(
F.softmax(attn, dim=-1), p=self.att_weights_dropout, training=self.training
) # [B, #heads, T, T']

# sequence of weighted sums over value sequence
v = value_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head).transpose(
1, 2
) # [B, #heads, T', F']
attn_output = (
torch.matmul(attn_output_weights, v).transpose(1, 2).contiguous().view(batch_dim_size, -1, self.embed_dim)
) # [B, T, F]

output_tensor = self.out_proj(attn_output)

output_tensor = F.dropout(output_tensor, p=self.dropout, training=self.training) # [B,T,F]

return output_tensor
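A hedged usage sketch of the new module, using only the config fields and forward signature introduced in this diff; the concrete dimensions, dropout rates, and rel_pos_clip value are made-up example numbers. The last few lines reproduce, outside the module, how the clipped relative-position indices into rel_pos_embeddings are formed (following Shaw et al.):

import torch

from i6_models.parts.conformer import ConformerMHSARelPosV1, ConformerMHSARelPosV1Config

cfg = ConformerMHSARelPosV1Config(
    input_dim=64,            # example value; must be divisible by num_att_heads
    num_att_heads=8,
    rel_pos_clip=16,         # relative distances beyond +/-16 share an embedding
    att_weights_dropout=0.1,
    dropout=0.1,
)
mhsa = ConformerMHSARelPosV1(cfg)

batch, time = 2, 10
features = torch.randn(batch, time, cfg.input_dim)         # [B, T, F]
sequence_mask = torch.ones(batch, time, dtype=torch.bool)  # True = inside the sequence
sequence_mask[1, 7:] = False                               # second sequence has 3 padded frames
out = mhsa(features, sequence_mask)                        # [B, T, F]

# the clipped relative-position indices that forward() uses to index rel_pos_embeddings
range_vec = torch.arange(time)
distance_mat = range_vec[None, :] - range_vec[:, None]                    # [T, T'], key minus query position
clipped = torch.clamp(distance_mat, -cfg.rel_pos_clip, cfg.rel_pos_clip)  # clip to [-rel_pos_clip, rel_pos_clip]
final_mat = clipped + cfg.rel_pos_clip                                    # shift to [0, 2 * rel_pos_clip] for lookup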