Add MHSA module, Conformer block and encoder with relative PE #55
Merged
29 commits (all by kuacakuaca):
751da88  initial commit
873042a  fix
e066c60  add flag for broadcasting dropout
61b87f5  fix
591c777  use kwargs in ffn test
56a1d48  black
69f733c  fix
4df2739  change from torch.matmul to torch.einsum
8f9d2c2  move to V2
647ee05  Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
31fec38  add test for ConformerMHSARelPosV1
2b19cd3  add validity check, fix param. init.
2bf2c89  black
600752e  make dropout model modules
7eb61a1  update docstring
b8db085  Merge remote-tracking branch 'origin/main' into ping_relative_pe_mhsa
83b23c9  adress feedback
c2a301f  update docstring
a4929dc  Apply suggestions from code review
338ff2c  Update i6_models/parts/conformer/mhsa_rel_pos.py
c0c706e  black
7f4decd  fix & add test case
5bf24d2  try fixing
0b9fb9d  remove default, update docstring and test case
f9cdf6e  fix espnet version
1137f24  Update requirements_dev.txt
33aa1f1  remove typeguard from requirements_dev
88f3702  expand docstring
fa3c4ad  make it consistent
Files changed:

i6_models/assemblies/conformer/__init__.py:
@@ -1,2 +1,3 @@
 from .conformer_v1 import *
 from .conformer_v2 import *
+from .conformer_rel_pos_v1 import *

i6_models/assemblies/conformer/conformer_rel_pos_v1.py (new file):
@@ -0,0 +1,125 @@
from __future__ import annotations

__all__ = [
    "ConformerRelPosBlockV1Config",
    "ConformerRelPosEncoderV1Config",
    "ConformerRelPosBlockV1",
    "ConformerRelPosEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass, field
from typing import List

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
    ConformerConvolutionV1,
    ConformerConvolutionV1Config,
    ConformerMHSARelPosV1,
    ConformerMHSARelPosV1Config,
    ConformerPositionwiseFeedForwardV1,
    ConformerPositionwiseFeedForwardV1Config,
)
from i6_models.assemblies.conformer import ConformerEncoderV2


@dataclass
class ConformerRelPosBlockV1Config(ModelConfiguration):
    """
    Attributes:
        ff_cfg: Configuration for ConformerPositionwiseFeedForwardV1
        mhsa_cfg: Configuration for ConformerMHSARelPosV1
        conv_cfg: Configuration for ConformerConvolutionV1
        modules: List of modules to use for ConformerRelPosBlockV1:
            "ff" for the feed-forward module, "mhsa" for the multi-head self-attention module, "conv" for the convolution module
        scales: List of scales to apply to the module outputs before the residual connection
    """

    # nested configurations
    ff_cfg: ConformerPositionwiseFeedForwardV1Config
    mhsa_cfg: ConformerMHSARelPosV1Config
    conv_cfg: ConformerConvolutionV1Config
    modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
    scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])

    def __post_init__(self):
        super().__post_init__()
        assert len(self.modules) == len(self.scales), "modules and scales must have the same length"
        for module_name in self.modules:
            assert module_name in ["ff", "mhsa", "conv"], f"module {module_name} not supported"
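The defaults above encode the standard Macaron-style Conformer layout: two half-step feed-forward modules (scale 0.5) around self-attention and convolution (scale 1.0). A minimal, self-contained sketch of the scaled-residual pattern these two fields drive (toy linear layers standing in for the real submodules):

import torch
from torch import nn

modules = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])  # toy stand-ins
scales = [0.5, 1.0, 1.0, 0.5]  # half-step feed-forward at both ends

x = torch.randn(2, 5, 8)  # [B, T, F]
for scale, module in zip(scales, modules):
    x = scale * module(x) + x  # scaled residual, as in the block's forward()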


class ConformerRelPosBlockV1(nn.Module):
    """
    Conformer block module, modifications compared to ConformerBlockV1:
    - uses ConformerMHSARelPosV1 as the MHSA module
    - enables constructing the block with a self-defined module list, as in ConformerBlockV2
    """

    def __init__(self, cfg: ConformerRelPosBlockV1Config):
        """
        :param cfg: conformer block configuration with subunits for the different conformer parts
        """
        super().__init__()

        modules = []
        for module_name in cfg.modules:
            if module_name == "ff":
                modules.append(ConformerPositionwiseFeedForwardV1(cfg=cfg.ff_cfg))
            elif module_name == "mhsa":
                modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg))
            elif module_name == "conv":
                modules.append(ConformerConvolutionV1(model_cfg=cfg.conv_cfg))
            else:
                raise NotImplementedError

        self.module_list = nn.ModuleList(modules)
        self.scales = cfg.scales
        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
        """
        :param x: input tensor of shape [B, T, F]
        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
        :return: torch.Tensor of shape [B, T, F]
        """
        for scale, module in zip(self.scales, self.module_list):
            if isinstance(module, ConformerMHSARelPosV1):
                x = scale * module(x, sequence_mask) + x
            else:
                x = scale * module(x) + x
        x = self.final_layer_norm(x)  # [B, T, F]
        return x
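forward() expects sequence_mask to mark the valid positions of each padded sequence. A short sketch, independent of i6_models, of how such a [B, T] boolean mask is typically built from per-sequence lengths:

import torch

lengths = torch.tensor([7, 5, 3])  # valid length of each sequence in the batch
T = int(lengths.max())
# True (1) inside the sequence, False (0) for padded positions, shape [B, T].
sequence_mask = torch.arange(T)[None, :] < lengths[:, None]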


@dataclass
class ConformerRelPosEncoderV1Config(ModelConfiguration):
    """
    Attributes:
        num_layers: Number of conformer layers in the conformer encoder
        frontend: module factory bundling a frontend module class and its corresponding config
        block_cfg: Configuration for ConformerRelPosBlockV1
    """

    num_layers: int

    # nested configurations
    frontend: ModuleFactoryV1
    block_cfg: ConformerRelPosBlockV1Config


class ConformerRelPosEncoderV1(ConformerEncoderV2):
    """
    Modifications compared to ConformerEncoderV2:
    - uses multi-headed self-attention with Shaw's relative positional encoding
    """

    def __init__(self, cfg: ConformerRelPosEncoderV1Config):
        """
        :param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
        """
        super().__init__(cfg)

        self.frontend = cfg.frontend()
        self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
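In the constructor above, self.frontend = cfg.frontend() works because ModuleFactoryV1 is callable: it bundles a module class with its config and instantiates the module when called. A rough, self-contained stand-in for that pattern (a toy sketch, not the actual i6_models implementation, which may differ in details such as type checking):

from dataclasses import dataclass

@dataclass
class ToyModuleFactory:
    module_class: type
    cfg: object

    def __call__(self):
        # Construct the module lazily, only when the factory is called.
        return self.module_class(self.cfg)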

i6_models/parts/conformer/__init__.py:
@@ -1,4 +1,5 @@
 from .convolution import *
 from .feedforward import *
 from .mhsa import *
+from .mhsa_rel_pos import *
 from .norm import *

Review discussion:
If we later also introduce broadcasting the dropout over the batch dimension, would we introduce another parameter? Or should we already take care of that here and make it an Enum or so?
Reply:
In the RETURNN TF backend, we simply specified a list of axes over which to apply (or not apply) broadcasting. I find that more intuitive and explicit than an enum listing all the possibilities.
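To make the comparison concrete, here is a minimal PyTorch sketch of the axes-list approach (a hypothetical helper, not part of i6_models or RETURNN): the dropout mask is sampled with size 1 along the broadcast axes, so the same mask is shared across them.

import torch

def dropout_broadcast(x: torch.Tensor, p: float, broadcast_axes: tuple = ()) -> torch.Tensor:
    """Dropout that shares its mask along the given axes (training mode assumed)."""
    if p == 0.0:
        return x
    # Size-1 dims along broadcast_axes make the mask broadcast across them,
    # e.g. broadcast_axes=(1,) reuses one mask for every time step of [B, T, F].
    mask_shape = [1 if dim in broadcast_axes else size for dim, size in enumerate(x.shape)]
    keep = (torch.rand(mask_shape, device=x.device) >= p).to(x.dtype)
    return x * keep / (1.0 - p)

x = torch.randn(2, 5, 8)  # [B, T, F]
y = dropout_broadcast(x, p=0.1, broadcast_axes=(1,))  # same mask for all time steps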