Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MHSA module, Conformer block and encoder with relative PE #55

Merged
merged 29 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
751da88
initial commit
kuacakuaca Jun 24, 2024
873042a
fix
kuacakuaca Jun 25, 2024
e066c60
add flag for broadcasting dropout
kuacakuaca Jul 1, 2024
61b87f5
fix
kuacakuaca Jul 2, 2024
591c777
use kwargs in ffn test
kuacakuaca Jul 4, 2024
56a1d48
black
kuacakuaca Jul 4, 2024
69f733c
fix
kuacakuaca Jul 4, 2024
4df2739
change from torch.matmul to torch.einsum
kuacakuaca Jul 18, 2024
8f9d2c2
move to V2
kuacakuaca Aug 14, 2024
647ee05
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Aug 14, 2024
31fec38
add test for ConformerMHSARelPosV1
kuacakuaca Aug 14, 2024
2b19cd3
add validity check, fix param. init.
kuacakuaca Aug 14, 2024
2bf2c89
black
kuacakuaca Aug 14, 2024
600752e
make dropout model modules
kuacakuaca Aug 21, 2024
7eb61a1
update docstring
kuacakuaca Aug 28, 2024
b8db085
Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
kuacakuaca Sep 3, 2024
83b23c9
adress feedback
kuacakuaca Sep 4, 2024
c2a301f
update docstring
kuacakuaca Sep 4, 2024
a4929dc
Apply suggestions from code review
kuacakuaca Sep 5, 2024
338ff2c
Update i6_models/parts/conformer/mhsa_rel_pos.py
kuacakuaca Sep 5, 2024
c0c706e
black
kuacakuaca Sep 5, 2024
7f4decd
fix & add test case
kuacakuaca Sep 5, 2024
5bf24d2
try fixing
kuacakuaca Sep 5, 2024
0b9fb9d
remove default, update docstring and test case
kuacakuaca Sep 6, 2024
f9cdf6e
fix espnet version
kuacakuaca Sep 6, 2024
1137f24
Update requirements_dev.txt
kuacakuaca Sep 9, 2024
33aa1f1
remove typeguard from requirements_dev
kuacakuaca Sep 9, 2024
88f3702
expand docstring
kuacakuaca Sep 12, 2024
fa3c4ad
make it consistent
kuacakuaca Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions i6_models/assemblies/conformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .conformer_v1 import *
from .conformer_v2 import *
from .conformer_rel_pos_v1 import *
126 changes: 126 additions & 0 deletions i6_models/assemblies/conformer/conformer_rel_pos_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from __future__ import annotations

__all__ = [
"ConformerRelPosBlockV1Config",
"ConformerRelPosEncoderV1Config",
"ConformerRelPosBlockV1",
"ConformerRelPosEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass, field
from typing import List

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
ConformerConvolutionV2,
ConformerConvolutionV2Config,
ConformerMHSARelPosV1,
ConformerMHSARelPosV1Config,
ConformerPositionwiseFeedForwardV2,
ConformerPositionwiseFeedForwardV2Config,
)
from i6_models.assemblies.conformer import ConformerEncoderV2


@dataclass
class ConformerRelPosBlockV1Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for ConformerPositionwiseFeedForwardV2
mhsa_cfg: Configuration for ConformerMHSARelPosV1
conv_cfg: Configuration for ConformerConvolutionV2
modules: List of modules to use for ConformerRelPosBlockV1,
"ff" for feed forward module, "mhsa" for multi-head self attention module, "conv" for conv module
scales: List of scales to apply to the module outputs before the residual connection
"""

# nested configurations
ff_cfg: ConformerPositionwiseFeedForwardV2Config
mhsa_cfg: ConformerMHSARelPosV1Config
conv_cfg: ConformerConvolutionV2Config
modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])

def __post__init__(self):
super().__post_init__()
assert len(self.modules) == len(self.scales), "modules and scales must have same length"
for module_name in self.modules:
assert module_name in ["ff", "mhsa", "conv"], "module not supported"


class ConformerRelPosBlockV1(nn.Module):
"""
Conformer block module, modifications compared to ConformerBlockV1:
- uses ConfomerMHSARelPosV1 as MHSA module
- enable constructing the block with self-defined module_list as ConformerBlockV2
"""

def __init__(self, cfg: ConformerRelPosBlockV1Config):
"""
:param cfg: conformer block configuration with subunits for the different conformer parts
"""
super().__init__()

modules = []
for module_name in cfg.modules:
if module_name == "ff":
modules.append(ConformerPositionwiseFeedForwardV2(cfg=cfg.ff_cfg))
elif module_name == "mhsa":
modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg))
elif module_name == "conv":
modules.append(ConformerConvolutionV2(model_cfg=cfg.conv_cfg))
else:
raise NotImplementedError

self.module_list = nn.ModuleList(modules)
self.scales = cfg.scales
self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
:param x: input tensor of shape [B, T, F]
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
:return: torch.Tensor of shape [B, T, F]
"""
for scale, module in zip(self.scales, self.module_list):
if isinstance(module, ConformerMHSARelPosV1):
x = scale * module(x, sequence_mask) + x
else:
x = scale * module(x) + x
x = self.final_layer_norm(x) # [B, T, F]
return x


@dataclass
class ConformerRelPosEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of conformer layers in the conformer encoder
frontend: A pair of ConformerFrontend and corresponding config
block_cfg: Configuration for ConformerRelPosBlockV1
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: ConformerRelPosBlockV1Config


class ConformerRelPosEncoderV1(ConformerEncoderV2):
"""
Modifications compared to ConformerEncoderV2:
- supports Shaw's relative positional encoding using learnable position embeddings
and Transformer-XL style relative PE using fixed sinusoidal or learnable position embeddings
"""

def __init__(self, cfg: ConformerRelPosEncoderV1Config):
"""
:param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
"""
super().__init__(cfg)

self.frontend = cfg.frontend()
self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
1 change: 1 addition & 0 deletions i6_models/parts/conformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .convolution import *
from .feedforward import *
from .mhsa import *
from .mhsa_rel_pos import *
from .norm import *
84 changes: 82 additions & 2 deletions i6_models/parts/conformer/convolution.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from __future__ import annotations

__all__ = ["ConformerConvolutionV1", "ConformerConvolutionV1Config"]
__all__ = [
"ConformerConvolutionV1",
"ConformerConvolutionV1Config",
"ConformerConvolutionV2",
"ConformerConvolutionV2Config",
]

from dataclasses import dataclass
from copy import deepcopy

import torch
from torch import nn
from i6_models.config import ModelConfiguration
from typing import Callable, Union
from typing import Callable, Union, Optional


@dataclass
Expand Down Expand Up @@ -85,3 +90,78 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
tensor = self.pointwise_conv2(tensor)

return self.dropout(tensor)


@dataclass
class ConformerConvolutionV2Config(ConformerConvolutionV1Config):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
Allows even kernel size
albertz marked this conversation as resolved.
Show resolved Hide resolved
"""

dropout_broadcast_axes: Optional[str] = None
albertz marked this conversation as resolved.
Show resolved Hide resolved

def check_valid(self):
assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes"

albertz marked this conversation as resolved.
Show resolved Hide resolved
assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"
albertz marked this conversation as resolved.
Show resolved Hide resolved


class ConformerConvolutionV2(ConformerConvolutionV1):
"""
Augments ConformerMHSAV1 with dropout broadcasting
"""

def __init__(self, model_cfg: ConformerConvolutionV2Config):
"""
:param model_cfg: model configuration for this module
"""
super().__init__(model_cfg)

self.dropout_broadcast_axes = model_cfg.dropout_broadcast_axes
self.dropout = (
nn.Dropout1d(model_cfg.dropout) if model_cfg.dropout_broadcast_axes else nn.Dropout(model_cfg.dropout)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a strong issue, but why do you have dropout here (in convolution) as explicit module but not for Feed-Forward and MHSA? I think this should be consistent. @michelwi @christophmluscher @curufinwe I would prefer to have it explicit, because then it is printed when you print the model.

Copy link
Collaborator Author

@kuacakuaca kuacakuaca Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just followed the V1 implementations. But I could make them all modules. Anyone against that?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am OK with the suggested change. It would improve readability :)

)

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
:param tensor: input tensor of shape [B,T,F]
:return: torch.Tensor of shape [B,T,F]
"""
tensor = self.layer_norm(tensor)
tensor = self.pointwise_conv1(tensor) # [B,T,2F]
tensor = nn.functional.glu(tensor, dim=-1) # [B,T,F]

# conv layers expect shape [B,F,T] so we have to transpose here
tensor = tensor.transpose(1, 2) # [B,F,T]
tensor = self.depthwise_conv(tensor)

tensor = self.norm(tensor)
tensor = tensor.transpose(1, 2) # transpose back to [B,T,F]

tensor = self.activation(tensor)
tensor = self.pointwise_conv2(tensor)

if self.dropout_broadcast_axes is None:
tensor = self.dropout(tensor)
elif self.dropout_broadcast_axes == "T":
tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
elif self.dropout_broadcast_axes == "B":
tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
elif self.dropout_broadcast_axes == "BT":
batch_dim_size = tensor.shape[0]
feature_dim_size = tensor.shape[-1]

tensor = (
self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
.transpose(0, 1)
.reshape(batch_dim_size, -1, feature_dim_size)
)
albertz marked this conversation as resolved.
Show resolved Hide resolved

return tensor
76 changes: 74 additions & 2 deletions i6_models/parts/conformer/feedforward.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from __future__ import annotations

__all__ = ["ConformerPositionwiseFeedForwardV1", "ConformerPositionwiseFeedForwardV1Config"]
__all__ = [
"ConformerPositionwiseFeedForwardV1",
"ConformerPositionwiseFeedForwardV1Config",
"ConformerPositionwiseFeedForwardV2",
"ConformerPositionwiseFeedForwardV2Config",
]

from dataclasses import dataclass
from typing import Callable
from typing import Callable, Optional

import torch
from torch import nn
Expand Down Expand Up @@ -53,3 +58,70 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
tensor = self.linear_out(tensor) # [B,T,F]
tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training) # [B,T,F]
return tensor


@dataclass
class ConformerPositionwiseFeedForwardV2Config(ConformerPositionwiseFeedForwardV1Config):
"""
New attribute:
dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. "T" for broadcasting to the time axis
setting to None to disable broadcasting
"""

dropout_broadcast_axes: Optional[str] = None
albertz marked this conversation as resolved.
Show resolved Hide resolved

def check_valid(self):
assert self.dropout_broadcast_axes is None or self.dropout_broadcast_axes in [
"B",
"T",
"BT",
], "invalid value, supported are None, 'B', 'T' and 'BT'"
albertz marked this conversation as resolved.
Show resolved Hide resolved

def __post__init__(self):
super().__post_init__()
self.check_valid()


class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1):
"""
Augments ConformerPositionwiseFeedForwardV1 with dropout broadcasting
"""

def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config):
super().__init__(cfg)

self.dropout = nn.Dropout1d(cfg.dropout) if cfg.dropout_broadcast_axes else nn.Dropout(cfg.dropout)
self.dropout_broadcast_axes = cfg.dropout_broadcast_axes

def _broadcast_dropout(self, tensor: torch.Tensor) -> torch.Tensor:
if self.dropout_broadcast_axes is None:
tensor = self.dropout(tensor)
elif self.dropout_broadcast_axes == "T":
tensor = self.dropout(tensor.transpose(1, 2)).transpose(1, 2)
elif self.dropout_broadcast_axes == "B":
tensor = self.dropout(tensor.permute(1, 2, 0)).permute(2, 0, 1)
elif self.dropout_broadcast_axes == "BT":
batch_dim_size = tensor.shape[0]
feature_dim_size = tensor.shape[-1]

tensor = (
self.dropout(tensor.reshape(-1, feature_dim_size).transpose(0, 1))
.transpose(0, 1)
.reshape(batch_dim_size, -1, feature_dim_size)
)
albertz marked this conversation as resolved.
Show resolved Hide resolved
return tensor

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
"""
:param tensor: shape [B,T,F], F=input_dim
:return: shape [B,T,F], F=input_dim
"""
tensor = self.layer_norm(tensor)
tensor = self.linear_ff(tensor) # [B,T,F]
tensor = self.activation(tensor) # [B,T,F]

tensor = self._broadcast_dropout(tensor) # [B,T,F]
tensor = self.linear_out(tensor) # [B,T,F]
tensor = self._broadcast_dropout(tensor) # [B,T,F]

return tensor
Loading
Loading