Add MHSA module, Conformer block and encoder with relative PE #55

Merged
merged 29 commits into main from ping_relative_pe_mhsa on Sep 12, 2024

Changes from 7 commits

Commits (29)
751da88
initial commit
kuacakuaca Jun 24, 2024
873042a
fix
kuacakuaca Jun 25, 2024
e066c60
add flag for broadcasting dropout
kuacakuaca Jul 1, 2024
61b87f5
fix
kuacakuaca Jul 2, 2024
591c777
use kwargs in ffn test
kuacakuaca Jul 4, 2024
56a1d48
black
kuacakuaca Jul 4, 2024
69f733c
fix
kuacakuaca Jul 4, 2024
4df2739
change from torch.matmul to torch.einsum
kuacakuaca Jul 18, 2024
8f9d2c2
move to V2
kuacakuaca Aug 14, 2024
647ee05
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Aug 14, 2024
31fec38
add test for ConformerMHSARelPosV1
kuacakuaca Aug 14, 2024
2b19cd3
add validity check, fix param. init.
kuacakuaca Aug 14, 2024
2bf2c89
black
kuacakuaca Aug 14, 2024
600752e
make dropout model modules
kuacakuaca Aug 21, 2024
7eb61a1
update docstring
kuacakuaca Aug 28, 2024
b8db085
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Sep 3, 2024
83b23c9
adress feedback
kuacakuaca Sep 4, 2024
c2a301f
update docstring
kuacakuaca Sep 4, 2024
a4929dc
Apply suggestions from code review
kuacakuaca Sep 5, 2024
338ff2c
Update i6_models/parts/conformer/mhsa_rel_pos.py
kuacakuaca Sep 5, 2024
c0c706e
black
kuacakuaca Sep 5, 2024
7f4decd
fix & add test case
kuacakuaca Sep 5, 2024
5bf24d2
try fixing
kuacakuaca Sep 5, 2024
0b9fb9d
remove default, update docstring and test case
kuacakuaca Sep 6, 2024
f9cdf6e
fix espnet version
kuacakuaca Sep 6, 2024
1137f24
Update requirements_dev.txt
kuacakuaca Sep 9, 2024
33aa1f1
remove typeguard from requirements_dev
kuacakuaca Sep 9, 2024
88f3702
expand docstring
kuacakuaca Sep 12, 2024
fa3c4ad
make it consistent
kuacakuaca Sep 12, 2024
1 change: 1 addition & 0 deletions i6_models/assemblies/conformer/__init__.py
@@ -1,2 +1,3 @@
from .conformer_v1 import *
from .conformer_v2 import *
from .conformer_rel_pos_v1 import *
125 changes: 125 additions & 0 deletions i6_models/assemblies/conformer/conformer_rel_pos_v1.py
@@ -0,0 +1,125 @@
from __future__ import annotations

__all__ = [
"ConformerRelPosBlockV1Config",
"ConformerRelPosEncoderV1Config",
"ConformerRelPosBlockV1",
"ConformerRelPosEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass, field
from typing import List

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
    ConformerConvolutionV1,
    ConformerConvolutionV1Config,
    ConformerMHSARelPosV1,
    ConformerMHSARelPosV1Config,
    ConformerPositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config,
)
from i6_models.assemblies.conformer import ConformerEncoderV2


@dataclass
class ConformerRelPosBlockV1Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for ConformerPositionwiseFeedForwardV1
mhsa_cfg: Configuration for ConformerMHSARelPosV1
conv_cfg: Configuration for ConformerConvolutionV1
modules: List of modules to use for ConformerRelPosBlockV1,
"ff" for feed forward module, "mhsa" for multi-head self attention module, "conv" for conv module
scales: List of scales to apply to the module outputs before the residual connection
"""

# nested configurations
ff_cfg: ConformerPositionwiseFeedForwardV1Config
mhsa_cfg: ConformerMHSARelPosV1Config
conv_cfg: ConformerConvolutionV1Config
modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])

def __post__init__(self):
super().__post_init__()
assert len(self.modules) == len(self.scales), "modules and scales must have same length"
for module_name in self.modules:
assert module_name in ["ff", "mhsa", "conv"], "module not supported"


class ConformerRelPosBlockV1(nn.Module):
"""
Conformer block module, modifications compared to ConformerBlockV1:
- uses ConfomerMHSARelPosV1 as MHSA module
- enable constructing the block with self-defined module_list as ConformerBlockV2
"""

def __init__(self, cfg: ConformerRelPosBlockV1Config):
"""
:param cfg: conformer block configuration with subunits for the different conformer parts
"""
super().__init__()

modules = []
for module_name in cfg.modules:
if module_name == "ff":
modules.append(ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg))
elif module_name == "mhsa":
modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg))
elif module_name == "conv":
modules.append(ConformerConvolutionV1(model_cfg=cfg.conv_cfg))
else:
raise NotImplementedError

self.module_list = nn.ModuleList(modules)
self.scales = cfg.scales
self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
:param x: input tensor of shape [B, T, F]
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
:return: torch.Tensor of shape [B, T, F]
"""
for scale, module in zip(self.scales, self.module_list):
if isinstance(module, ConformerMHSARelPosV1):
x = scale * module(x, sequence_mask) + x
else:
x = scale * module(x) + x
x = self.final_layer_norm(x) # [B, T, F]
return x


@dataclass
class ConformerRelPosEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of conformer layers in the conformer encoder
frontend: A pair of ConformerFrontend and corresponding config
block_cfg: Configuration for ConformerRelPosBlockV1
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: ConformerRelPosBlockV1Config


class ConformerRelPosEncoderV1(ConformerEncoderV2):
"""
Modifications compared to ConformerEncoderV2:
- uses multi-headed self-attention with Shaw's relative positional encoding
"""

def __init__(self, cfg: ConformerRelPosEncoderV1Config):
"""
:param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
"""
super().__init__(cfg)

self.frontend = cfg.frontend()
self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
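For reference, a minimal standalone sketch (not part of this PR) of the scaled-residual pattern that `modules` and `scales` drive in ConformerRelPosBlockV1. `ToyModule` is a hypothetical stand-in for the FF/MHSA/conv parts; the real block additionally passes `sequence_mask` to the MHSA module and applies a final LayerNorm.

import torch
from torch import nn


class ToyModule(nn.Module):
    """Hypothetical stand-in for a conformer sub-module (FF, MHSA or conv)."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


dim = 8
module_list = nn.ModuleList([ToyModule(dim) for _ in range(4)])
scales = [0.5, 1.0, 1.0, 0.5]  # default config: half-step FF, full-step MHSA/conv, half-step FF

x = torch.randn(2, 5, dim)  # [B, T, F]
for scale, module in zip(scales, module_list):
    x = scale * module(x) + x  # scale the module output, then add the residual
print(x.shape)  # torch.Size([2, 5, 8])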
1 change: 1 addition & 0 deletions i6_models/parts/conformer/__init__.py
@@ -1,4 +1,5 @@
from .convolution import *
from .feedforward import *
from .mhsa import *
from .mhsa_rel_pos import *
from .norm import *
12 changes: 10 additions & 2 deletions i6_models/parts/conformer/convolution.py
@@ -20,13 +20,15 @@ class ConformerConvolutionV1Config(ModelConfiguration):
        dropout: dropout probability
        activation: activation function applied after normalization
        norm: normalization layer with input of shape [N,C,T]
+        broadcast_dropout: whether to broadcast the dropout mask over the time axis, i.e. drop the same feature channels at every time step
Contributor (review comment):
If we later also introduce broadcasting the dropout over the batch dimension, would we introduce another parameter? Or should we already take care of that here and make it an Enum or so?

Member (review comment):
In the RETURNN TF backend, we simply specified a list of axes on which to apply (or not apply) broadcasting. I find this more intuitive and explicit than an enum listing all possibilities.

    """

    channels: int
    kernel_size: int
    dropout: float
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
    norm: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]
+    broadcast_dropout: bool = False

    def check_valid(self):
        assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"
@@ -62,7 +64,8 @@ def __init__(self, model_cfg: ConformerConvolutionV1Config):
        self.pointwise_conv2 = nn.Linear(in_features=model_cfg.channels, out_features=model_cfg.channels)
        self.layer_norm = nn.LayerNorm(model_cfg.channels)
        self.norm = deepcopy(model_cfg.norm)
-        self.dropout = nn.Dropout(model_cfg.dropout)
+        self.dropout = nn.Dropout1d(model_cfg.dropout) if model_cfg.broadcast_dropout else nn.Dropout(model_cfg.dropout)
+        self.broadcast_dropout = model_cfg.broadcast_dropout
        self.activation = model_cfg.activation

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
@@ -84,4 +87,9 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        tensor = self.activation(tensor)
        tensor = self.pointwise_conv2(tensor)

-        return self.dropout(tensor)
+        if self.broadcast_dropout:
+            tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
+        else:
+            tensor = self.dropout(tensor)
+
+        return tensor
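As a side note on the broadcast flag: nn.Dropout1d zeroes whole channels of a [N, C, L] input, so applying it to the transposed [B, F, T] tensor drops the same feature channels at every time step, whereas plain nn.Dropout drops each (time, feature) entry independently. A small standalone sketch (not part of the PR; assumes a PyTorch version that provides nn.Dropout1d, i.e. >= 1.12):

import torch
from torch import nn

torch.manual_seed(0)
x = torch.ones(1, 4, 3)  # [B, T, F] with B=1, T=4, F=3

per_position = nn.Dropout(p=0.5)
broadcast = nn.Dropout1d(p=0.5)
per_position.train()
broadcast.train()

# Plain dropout: each (time, feature) entry gets its own Bernoulli mask.
print(per_position(x))

# Broadcast dropout: transpose to [B, F, T] so Dropout1d zeroes whole feature
# channels, i.e. the same features are dropped at every time step.
print(broadcast(x.transpose(1, 2)).transpose(1, 2))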
19 changes: 16 additions & 3 deletions i6_models/parts/conformer/feedforward.py
@@ -18,12 +18,14 @@ class ConformerPositionwiseFeedForwardV1Config(ModelConfiguration):
        input_dim: input dimension
        hidden_dim: hidden dimension (normally set to 4*input_dim as suggested by the paper)
        dropout: dropout probability
+        broadcast_dropout: whether to broadcast the dropout mask over the time axis, i.e. drop the same feature channels at every time step
        activation: activation function
    """

    input_dim: int
    hidden_dim: int
    dropout: float
+    broadcast_dropout: bool = False
    activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.silu


@@ -40,6 +42,7 @@ def __init__(self, cfg: ConformerPositionwiseFeedForwardV1Config):
        self.activation = cfg.activation
        self.linear_out = nn.Linear(in_features=cfg.hidden_dim, out_features=cfg.input_dim, bias=True)
        self.dropout = cfg.dropout
+        self.broadcast_dropout = cfg.broadcast_dropout

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
@@ -49,7 +52,17 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        tensor = self.layer_norm(tensor)
        tensor = self.linear_ff(tensor)  # [B,T,F]
        tensor = self.activation(tensor)  # [B,T,F]
-        tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)  # [B,T,F]
-        tensor = self.linear_out(tensor)  # [B,T,F]
-        tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)  # [B,T,F]
+
+        if self.broadcast_dropout:
+            tensor = nn.functional.dropout1d(tensor.transpose(1, 2), p=self.dropout, training=self.training).transpose(
+                1, 2
+            )
+            tensor = self.linear_out(tensor)
+            tensor = nn.functional.dropout1d(tensor.transpose(1, 2), p=self.dropout, training=self.training).transpose(
+                1, 2
+            )
+        else:
+            tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)  # [B,T,F]
+            tensor = self.linear_out(tensor)  # [B,T,F]
+            tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)  # [B,T,F]
        return tensor
13 changes: 12 additions & 1 deletion i6_models/parts/conformer/mhsa.py
@@ -16,12 +16,14 @@ class ConformerMHSAV1Config(ModelConfiguration):
        num_att_heads: number of attention heads
        att_weights_dropout: attention weights dropout
        dropout: multi-headed self attention output dropout
+        broadcast_dropout: whether to broadcast the dropout mask over the time axis, i.e. drop the same feature channels at every time step
    """

    input_dim: int
    num_att_heads: int
    att_weights_dropout: float
    dropout: float
+    broadcast_dropout: bool = False

    def __post_init__(self) -> None:
        super().__post_init__()
@@ -42,6 +44,7 @@ def __init__(self, cfg: ConformerMHSAV1Config):
            cfg.input_dim, cfg.num_att_heads, dropout=cfg.att_weights_dropout, batch_first=True
        )
        self.dropout = cfg.dropout
+        self.broadcast_dropout = cfg.broadcast_dropout

    def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
        """
@@ -57,6 +60,14 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
        output_tensor, _ = self.mhsa(
            output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
        )  # [B,T,F]
-        output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]
+
+        if self.broadcast_dropout:
+            output_tensor = torch.nn.functional.dropout1d(
+                output_tensor.transpose(1, 2), p=self.dropout, training=self.training
+            ).transpose(1, 2)
+        else:
+            output_tensor = torch.nn.functional.dropout(
+                output_tensor, p=self.dropout, training=self.training
+            )  # [B,T,F]

        return output_tensor
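For context on the masking convention: sequence_mask is [B, T] with 1 (True) inside the sequence, and it is inverted into the key_padding_mask expected by nn.MultiheadAttention, where True marks padded positions to be ignored. A standalone sketch (not from the PR; the mask construction from lengths is an assumed helper, the repo computes sequence_mask elsewhere):

import torch
from torch import nn

lengths = torch.tensor([5, 3])  # two sequences, lengths 5 and 3
T = int(lengths.max())
# True inside the sequence, False at padded positions -> same convention as sequence_mask
sequence_mask = torch.arange(T)[None, :] < lengths[:, None]  # [B, T]
inv_sequence_mask = ~sequence_mask  # True marks positions the attention should ignore

mhsa = nn.MultiheadAttention(embed_dim=8, num_heads=2, dropout=0.1, batch_first=True)
x = torch.randn(2, T, 8)  # [B, T, F]
out, _ = mhsa(x, x, x, key_padding_mask=inv_sequence_mask, need_weights=False)
print(out.shape)  # torch.Size([2, 5, 8])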