Add MHSA module, Conformer block and encoder with relative PE #55

Merged (29 commits, Sep 12, 2024)
Changes from 18 commits
Commits
751da88
initial commit
kuacakuaca Jun 24, 2024
873042a
fix
kuacakuaca Jun 25, 2024
e066c60
add flag for broadcasting dropout
kuacakuaca Jul 1, 2024
61b87f5
fix
kuacakuaca Jul 2, 2024
591c777
use kwargs in ffn test
kuacakuaca Jul 4, 2024
56a1d48
black
kuacakuaca Jul 4, 2024
69f733c
fix
kuacakuaca Jul 4, 2024
4df2739
change from torch.matmul to torch.einsum
kuacakuaca Jul 18, 2024
8f9d2c2
move to V2
kuacakuaca Aug 14, 2024
647ee05
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Aug 14, 2024
31fec38
add test for ConformerMHSARelPosV1
kuacakuaca Aug 14, 2024
2b19cd3
add validity check, fix param. init.
kuacakuaca Aug 14, 2024
2bf2c89
black
kuacakuaca Aug 14, 2024
600752e
make dropout model modules
kuacakuaca Aug 21, 2024
7eb61a1
update docstring
kuacakuaca Aug 28, 2024
b8db085
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Sep 3, 2024
83b23c9
adress feedback
kuacakuaca Sep 4, 2024
c2a301f
update docstring
kuacakuaca Sep 4, 2024
a4929dc
Apply suggestions from code review
kuacakuaca Sep 5, 2024
338ff2c
Update i6_models/parts/conformer/mhsa_rel_pos.py
kuacakuaca Sep 5, 2024
c0c706e
black
kuacakuaca Sep 5, 2024
7f4decd
fix & add test case
kuacakuaca Sep 5, 2024
5bf24d2
try fixing
kuacakuaca Sep 5, 2024
0b9fb9d
remove default, update docstring and test case
kuacakuaca Sep 6, 2024
f9cdf6e
fix espnet version
kuacakuaca Sep 6, 2024
1137f24
Update requirements_dev.txt
kuacakuaca Sep 9, 2024
33aa1f1
remove typeguard from requirements_dev
kuacakuaca Sep 9, 2024
88f3702
expand docstring
kuacakuaca Sep 12, 2024
fa3c4ad
make it consistent
kuacakuaca Sep 12, 2024
1 change: 1 addition & 0 deletions i6_models/assemblies/conformer/__init__.py
@@ -1,2 +1,3 @@
from .conformer_v1 import *
from .conformer_v2 import *
from .conformer_rel_pos_v1 import *
126 changes: 126 additions & 0 deletions i6_models/assemblies/conformer/conformer_rel_pos_v1.py
@@ -0,0 +1,126 @@
from __future__ import annotations

__all__ = [
    "ConformerRelPosBlockV1Config",
    "ConformerRelPosEncoderV1Config",
    "ConformerRelPosBlockV1",
    "ConformerRelPosEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass, field
from typing import List

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
    ConformerConvolutionV2,
    ConformerConvolutionV2Config,
    ConformerMHSARelPosV1,
    ConformerMHSARelPosV1Config,
    ConformerPositionwiseFeedForwardV2,
    ConformerPositionwiseFeedForwardV2Config,
)
from i6_models.assemblies.conformer import ConformerEncoderV2


@dataclass
class ConformerRelPosBlockV1Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for ConformerPositionwiseFeedForwardV2
mhsa_cfg: Configuration for ConformerMHSARelPosV1
conv_cfg: Configuration for ConformerConvolutionV2
modules: List of modules to use for ConformerRelPosBlockV1,
"ff" for feed forward module, "mhsa" for multi-head self attention module, "conv" for conv module
scales: List of scales to apply to the module outputs before the residual connection
"""

# nested configurations
ff_cfg: ConformerPositionwiseFeedForwardV2Config
mhsa_cfg: ConformerMHSARelPosV1Config
conv_cfg: ConformerConvolutionV2Config
modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])

def __post__init__(self):
super().__post_init__()
assert len(self.modules) == len(self.scales), "modules and scales must have same length"
for module_name in self.modules:
assert module_name in ["ff", "mhsa", "conv"], "module not supported"


class ConformerRelPosBlockV1(nn.Module):
"""
Conformer block module, modifications compared to ConformerBlockV1:
- uses ConfomerMHSARelPosV1 as MHSA module
- enable constructing the block with self-defined module_list as ConformerBlockV2
"""

    def __init__(self, cfg: ConformerRelPosBlockV1Config):
        """
        :param cfg: conformer block configuration with subunits for the different conformer parts
        """
        super().__init__()

        modules = []
        for module_name in cfg.modules:
            if module_name == "ff":
                modules.append(ConformerPositionwiseFeedForwardV2(cfg=cfg.ff_cfg))
            elif module_name == "mhsa":
                modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg))
            elif module_name == "conv":
                modules.append(ConformerConvolutionV2(model_cfg=cfg.conv_cfg))
            else:
                raise NotImplementedError

        self.module_list = nn.ModuleList(modules)
        self.scales = cfg.scales
        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape [B, T, F]
        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
        :return: torch.Tensor of shape [B, T, F]
        """
        for scale, module in zip(self.scales, self.module_list):
            if isinstance(module, ConformerMHSARelPosV1):
                x = scale * module(x, sequence_mask) + x
            else:
                x = scale * module(x) + x
        x = self.final_layer_norm(x)  # [B, T, F]
        return x


@dataclass
class ConformerRelPosEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of conformer layers in the conformer encoder
frontend: A pair of ConformerFrontend and corresponding config
block_cfg: Configuration for ConformerRelPosBlockV1
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: ConformerRelPosBlockV1Config


class ConformerRelPosEncoderV1(ConformerEncoderV2):
"""
Modifications compared to ConformerEncoderV2:
- supports Shaw's relative positional encoding using learnable position embeddings
and Transformer-XL style relative PE using fixed sinusoidal or learnable position embeddings
"""

def __init__(self, cfg: ConformerRelPosEncoderV1Config):
"""
:param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
"""
super().__init__(cfg)

self.frontend = cfg.frontend()
self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
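
Note on the block layout: the default modules = ["ff", "mhsa", "conv", "ff"] with scales = [0.5, 1.0, 1.0, 0.5] gives the usual macaron-style Conformer block (two half-scaled feed-forward modules around MHSA and convolution). The forward pass applies each module as a scaled residual branch; a minimal stand-alone sketch of that pattern, using hypothetical stand-in modules rather than the real conformer parts:

import torch
from torch import nn

# Stand-ins for the "ff", "mhsa", "conv", "ff" parts; the real block uses the configured
# ConformerPositionwiseFeedForwardV2 / ConformerMHSARelPosV1 / ConformerConvolutionV2 modules.
modules = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
scales = [0.5, 1.0, 1.0, 0.5]

x = torch.randn(2, 5, 8)  # [B, T, F]
for scale, module in zip(scales, modules):
    x = scale * module(x) + x  # scaled residual connection, as in ConformerRelPosBlockV1.forward
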
1 change: 1 addition & 0 deletions i6_models/parts/conformer/__init__.py
@@ -1,4 +1,5 @@
from .convolution import *
from .feedforward import *
from .mhsa import *
from .mhsa_rel_pos import *
from .norm import *
69 changes: 67 additions & 2 deletions i6_models/parts/conformer/convolution.py
@@ -1,14 +1,20 @@
from __future__ import annotations

__all__ = ["ConformerConvolutionV1", "ConformerConvolutionV1Config"]
__all__ = [
    "ConformerConvolutionV1",
    "ConformerConvolutionV1Config",
    "ConformerConvolutionV2",
    "ConformerConvolutionV2Config",
]

from dataclasses import dataclass
from copy import deepcopy
from typing import Callable, Union, Optional, Literal

import torch
from torch import nn
from i6_models.config import ModelConfiguration
from typing import Callable, Union
from i6_models.parts.dropout import BroadcastDropout


@dataclass
@@ -85,3 +91,62 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        tensor = self.pointwise_conv2(tensor)

        return self.dropout(tensor)


@dataclass
class ConformerConvolutionV2Config(ConformerConvolutionV1Config):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
Allows even kernel size
albertz marked this conversation as resolved.
Show resolved Hide resolved
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None

def check_valid(self):
assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"

albertz marked this conversation as resolved.
Show resolved Hide resolved
assert self.dropout_broadcast_axes in [
None,
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"


class ConformerConvolutionV2(ConformerConvolutionV1):
"""
Augments ConformerMHSAV1 with dropout broadcasting
"""

def __init__(self, model_cfg: ConformerConvolutionV2Config):
"""
:param model_cfg: model configuration for this module
"""
super().__init__(model_cfg)

self.dropout = BroadcastDropout(model_cfg.dropout, dropout_broadcast_axes=model_cfg.dropout_broadcast_axes)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        :param tensor: input tensor of shape [B,T,F]
        :return: torch.Tensor of shape [B,T,F]
        """
        tensor = self.layer_norm(tensor)
        tensor = self.pointwise_conv1(tensor)  # [B,T,2F]
        tensor = nn.functional.glu(tensor, dim=-1)  # [B,T,F]

        # conv layers expect shape [B,F,T] so we have to transpose here
        tensor = tensor.transpose(1, 2)  # [B,F,T]
        tensor = self.depthwise_conv(tensor)

        tensor = self.norm(tensor)
        tensor = tensor.transpose(1, 2)  # transpose back to [B,T,F]

        tensor = self.activation(tensor)
        tensor = self.pointwise_conv2(tensor)

        tensor = self.dropout(tensor)

        return tensor
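
BroadcastDropout is imported from i6_models.parts.dropout and is not part of this diff. Conceptually it applies dropout with a mask that is shared (broadcast) along the configured axes of a [B, T, F] tensor, so with dropout_broadcast_axes="T" every time step of a sequence drops the same feature channels. A minimal sketch of that idea, assuming the [B, T, F] layout; this is not the actual i6_models implementation:

import torch
from torch import nn
from typing import Optional


class BroadcastDropoutSketch(nn.Module):
    """Dropout whose mask is shared along the given axes of a [B, T, F] tensor (illustrative only)."""

    def __init__(self, p: float, dropout_broadcast_axes: Optional[str] = None):
        super().__init__()
        self.p = p
        self.dropout_broadcast_axes = dropout_broadcast_axes  # None, "B", "T" or "BT"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return x
        if self.dropout_broadcast_axes is None:
            return nn.functional.dropout(x, p=self.p, training=True)  # ordinary element-wise dropout
        mask_shape = list(x.shape)
        if "B" in self.dropout_broadcast_axes:
            mask_shape[0] = 1  # share the mask across the batch axis
        if "T" in self.dropout_broadcast_axes:
            mask_shape[1] = 1  # share the mask across the time axis
        keep = (torch.rand(mask_shape, device=x.device) >= self.p).to(x.dtype)
        return x * keep / (1.0 - self.p)  # inverted-dropout scaling
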
59 changes: 57 additions & 2 deletions i6_models/parts/conformer/feedforward.py
@@ -1,14 +1,20 @@
from __future__ import annotations

__all__ = ["ConformerPositionwiseFeedForwardV1", "ConformerPositionwiseFeedForwardV1Config"]
__all__ = [
    "ConformerPositionwiseFeedForwardV1",
    "ConformerPositionwiseFeedForwardV1Config",
    "ConformerPositionwiseFeedForwardV2",
    "ConformerPositionwiseFeedForwardV2Config",
]

from dataclasses import dataclass
from typing import Callable
from typing import Callable, Optional, Literal

import torch
from torch import nn

from i6_models.config import ModelConfiguration
from i6_models.parts.dropout import BroadcastDropout


@dataclass
@@ -53,3 +59,52 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        tensor = self.linear_out(tensor)  # [B,T,F]
        tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)  # [B,T,F]
        return tensor


@dataclass
class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV1Config):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None

def check_valid(self):
assert self.dropout_broadcast_axes in [
None,
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"

def __post__init__(self):
super().__post_init__()
self.check_valid()


class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1):
"""
Augments ConformerPositionwiseFeedForwardV1 with dropout broadcasting
"""

def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config):
super().__init__(cfg)

self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
:param tensor: shape [B,T,F], F=input_dim
:return: shape [B,T,F], F=input_dim
"""
tensor = self.layer_norm(tensor)
tensor = self.linear_ff(tensor) # [B,T,F]
tensor = self.activation(tensor) # [B,T,F]

tensor = self.dropout(tensor) # [B,T,F]
tensor = self.linear_out(tensor) # [B,T,F]
tensor = self.dropout(tensor) # [B,T,F]

return tensor
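
A hedged usage sketch of the new feed-forward part: the V1 config fields (input_dim, hidden_dim, dropout, activation) are not shown in this diff and are assumed here from the existing ConformerPositionwiseFeedForwardV1Config; only dropout_broadcast_axes is introduced by this PR.

import torch
from torch import nn

from i6_models.parts.conformer import (
    ConformerPositionwiseFeedForwardV2,
    ConformerPositionwiseFeedForwardV2Config,
)

cfg = ConformerPositionwiseFeedForwardV2Config(
    input_dim=256,  # assumed V1 field
    hidden_dim=1024,  # assumed V1 field
    dropout=0.1,  # assumed V1 field
    activation=nn.functional.silu,  # assumed V1 field
    dropout_broadcast_axes="T",  # new: share the dropout mask over the time axis; None disables broadcasting
)
ffn = ConformerPositionwiseFeedForwardV2(cfg)
out = ffn(torch.randn(4, 100, 256))  # [B, T, F] -> [B, T, F]
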
59 changes: 58 additions & 1 deletion i6_models/parts/conformer/mhsa.py
@@ -1,11 +1,14 @@
from __future__ import annotations

__all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config"]
__all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config", "ConformerMHSAV2", "ConformerMHSAV2Config"]

from dataclasses import dataclass
from typing import Optional, Literal
import torch

from i6_models.config import ModelConfiguration
from i6_models.util import compat
from i6_models.parts.dropout import BroadcastDropout


@dataclass
@@ -60,3 +63,57 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to
        output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training)  # [B,T,F]

        return output_tensor


@dataclass
class ConformerMHSAV2Config(ConformerMHSAV1Config):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
"""

dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None

def check_valid(self):
assert self.dropout_broadcast_axes in [
None,
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"

def __post__init__(self):
super().__post_init__()
self.check_valid()


class ConformerMHSAV2(ConformerMHSAV1):
"""
Augments ConformerMHSAV1 with dropout broadcasting
"""

def __init__(self, cfg: ConformerMHSAV2Config):

super().__init__(cfg)

self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes)

def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
Apply layer norm and multi-head self attention and dropout

:param input_tensor: Input to the self attention of shape (B, T, F)
:param sequence_mask: Bool mask of shape (B, T), True signals within sequence, False outside, will be inverted to match the torch.nn.MultiheadAttention module
which will be applied/added to dot product, used to mask padded key positions out
"""
inv_sequence_mask = compat.logical_not(sequence_mask)
output_tensor = self.layernorm(input_tensor) # [B,T,F]

output_tensor, _ = self.mhsa(
output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False
) # [B,T,F]

output_tensor = self.dropout(output_tensor)

return output_tensor # [B,T,F]
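
The sequence_mask convention used above (True marks positions inside the sequence) is the inverse of what torch.nn.MultiheadAttention expects for key_padding_mask (True marks padded key positions), hence the logical_not before the attention call. A small stand-alone illustration in plain torch, with arbitrary shapes:

import torch

sequence_mask = torch.tensor(
    [[True, True, True, False], [True, True, False, False]]
)  # [B, T], True = inside the sequence
key_padding_mask = torch.logical_not(sequence_mask)  # True = padded key position

mhsa = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(2, 4, 8)  # [B, T, F]
out, _ = mhsa(x, x, x, key_padding_mask=key_padding_mask, need_weights=False)  # out: [B, T, F]
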