
add E-branchformer to i6_models #27

Merged (13 commits) on Oct 26, 2023
1 change: 1 addition & 0 deletions i6_models/assemblies/e_branchformer/__init__.py
@@ -0,0 +1 @@
from .e_branchformer_v1 import *
124 changes: 124 additions & 0 deletions i6_models/assemblies/e_branchformer/e_branchformer_v1.py
@@ -0,0 +1,124 @@
from __future__ import annotations

__all__ = [
"EbranchformerBlockV1Config",
"EbranchformerBlockV1",
"EbranchformerEncoderV1Config",
"EbranchformerEncoderV1",
]

import torch
from torch import nn
from dataclasses import dataclass
from typing import Tuple

from i6_models.config import ModelConfiguration, ModuleFactoryV1
from i6_models.parts.conformer import (
ConformerMHSAV1 as MHSAV1,
ConformerMHSAV1Config as MHSAV1Config,
ConformerPositionwiseFeedForwardV1 as PositionwiseFeedForwardV1,
ConformerPositionwiseFeedForwardV1Config as PositionwiseFeedForwardV1Config,
)
from i6_models.parts.e_branchformer import (
ConvolutionalGatingMLPV1Config,
ConvolutionalGatingMLPV1,
MergerV1Config,
MergerV1,
)


@dataclass
class EbranchformerBlockV1Config(ModelConfiguration):
"""
Attributes:
ff_cfg: Configuration for PositionwiseFeedForwardV1 module
mhsa_cfg: Configuration for MHSAV1 module
cgmlp_cfg: Configuration for ConvolutionalGatingMLPV1 module
merger_cfg: Configuration for MergerV1 module
"""

ff_cfg: PositionwiseFeedForwardV1Config
mhsa_cfg: MHSAV1Config
cgmlp_cfg: ConvolutionalGatingMLPV1Config
merger_cfg: MergerV1Config


class EbranchformerBlockV1(nn.Module):
"""
Ebranchformer block module
"""

def __init__(self, cfg: EbranchformerBlockV1Config):
"""
:param cfg: e-branchformer block configuration with subunits for the different e-branchformer parts
"""
super().__init__()
self.ff_1 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
self.mhsa = MHSAV1(cfg=cfg.mhsa_cfg)
self.cgmlp = ConvolutionalGatingMLPV1(model_cfg=cfg.cgmlp_cfg)
self.merger = MergerV1(model_cfg=cfg.merger_cfg)
self.ff_2 = PositionwiseFeedForwardV1(cfg=cfg.ff_cfg)
self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)

def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
"""
:param x: input tensor of shape [B, T, F]
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
:return: torch.Tensor of shape [B, T, F]
"""
x = 0.5 * self.ff_1(x) + x  # [B, T, F]
x_1 = self.mhsa(x, sequence_mask) # [B, T, F]
x_2 = self.cgmlp(x) # [B, T, F]
x = self.merger(x_1, x_2) + x # [B, T, F]
x = 0.5 * self.ff_2(x) + x  # [B, T, F]
x = self.final_layer_norm(x) # [B, T, F]
return x


@dataclass
class EbranchformerEncoderV1Config(ModelConfiguration):
"""
Attributes:
num_layers: Number of e-branchformer layers in the e-branchformer encoder
frontend: ModuleFactoryV1 bundling the frontend module and its corresponding config
block_cfg: Configuration for EbranchformerBlockV1
"""

num_layers: int

# nested configurations
frontend: ModuleFactoryV1
block_cfg: EbranchformerBlockV1Config


class EbranchformerEncoderV1(nn.Module):
"""
Implementation of the Branchformer with Enhanced merging (in short: E-Branchformer), as in the original publication.
The model consists of a frontend and a stack of N E-Branchformer blocks.
Cf. https://arxiv.org/pdf/2210.00077.pdf
"""

def __init__(self, cfg: EbranchformerEncoderV1Config):
"""
:param cfg: e-branchformer encoder configuration with subunits for frontend and e-branchformer blocks
"""
super().__init__()

self.frontend = cfg.frontend()
self.module_list = torch.nn.ModuleList([EbranchformerBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])

def forward(self, data_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param data_tensor: input tensor of shape [B, T', F']
:param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T']
:return: (output, out_seq_mask)
where output is torch.Tensor of shape [B, T, F],
out_seq_mask is a torch.Tensor of shape [B, T]

F': input feature dim, F: internal and output feature dim
T': data time dim, T: down-sampled time dim (internal time dim)
"""
x, sequence_mask = self.frontend(data_tensor, sequence_mask) # [B, T, F]
for module in self.module_list:
x = module(x, sequence_mask) # [B, T, F]

return x, sequence_mask
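
Usage sketch (not part of the PR): a minimal way to instantiate a single EbranchformerBlockV1 and run a forward pass. The hyperparameter values are illustrative only, and the keyword names assumed for the conformer part configs (input_dim, hidden_dim, dropout, activation for the feed-forward config; input_dim, num_att_heads, att_weights_dropout, dropout for the MHSA config) should be checked against the actual dataclasses in i6_models.parts.conformer.

import torch
from torch import nn

from i6_models.assemblies.e_branchformer import EbranchformerBlockV1, EbranchformerBlockV1Config
from i6_models.parts.conformer import ConformerMHSAV1Config, ConformerPositionwiseFeedForwardV1Config
from i6_models.parts.e_branchformer import ConvolutionalGatingMLPV1Config, MergerV1Config

input_dim = 256  # illustrative model dimension

block_cfg = EbranchformerBlockV1Config(
    ff_cfg=ConformerPositionwiseFeedForwardV1Config(
        input_dim=input_dim, hidden_dim=4 * input_dim, dropout=0.1, activation=nn.functional.silu
    ),
    mhsa_cfg=ConformerMHSAV1Config(input_dim=input_dim, num_att_heads=4, att_weights_dropout=0.1, dropout=0.1),
    cgmlp_cfg=ConvolutionalGatingMLPV1Config(
        input_dim=input_dim, hidden_dim=6 * input_dim, kernel_size=31, dropout=0.1, activation=nn.functional.gelu
    ),
    merger_cfg=MergerV1Config(input_dim=input_dim, kernel_size=31, dropout=0.1),
)
block = EbranchformerBlockV1(block_cfg)

x = torch.randn(2, 50, input_dim)  # [B, T, F]
sequence_mask = torch.ones(2, 50, dtype=torch.bool)  # 1/True marks positions within the sequence
out = block(x, sequence_mask=sequence_mask)  # [B, T, F]

Building the full EbranchformerEncoderV1 additionally requires wrapping a frontend module and its config in a ModuleFactoryV1 and passing it as the frontend field of EbranchformerEncoderV1Config.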
2 changes: 2 additions & 0 deletions i6_models/parts/e_branchformer/__init__.py
@@ -0,0 +1,2 @@
from .cgmlp import *
from .merge import *
81 changes: 81 additions & 0 deletions i6_models/parts/e_branchformer/cgmlp.py
@@ -0,0 +1,81 @@
from __future__ import annotations

__all__ = ["ConvolutionalGatingMLPV1Config", "ConvolutionalGatingMLPV1"]

from dataclasses import dataclass
from typing import Callable

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class ConvolutionalGatingMLPV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input dimension
hidden_dim: hidden dimension (normally set to 6*input_dim as suggested by the paper)
kernel_size: kernel size of the depthwise convolution layer
dropout: dropout probability
activation: activation function
"""

input_dim: int
hidden_dim: int
kernel_size: int
dropout: float
activation: Callable[[torch.Tensor], torch.Tensor] = nn.functional.gelu

def check_valid(self):
assert self.kernel_size % 2 == 1, "ConvolutionalGatingMLPV1 only supports odd kernel sizes"
assert self.hidden_dim % 2 == 0, "ConvolutionalGatingMLPV1 only supports even hidden_dim"

def __post_init__(self):
super().__post_init__()
self.check_valid()


class ConvolutionalGatingMLPV1(nn.Module):
"""Convolutional Gating MLP (cgMLP)."""

def __init__(self, model_cfg: ConvolutionalGatingMLPV1Config):
super().__init__()

self.layer_norm_input = nn.LayerNorm(model_cfg.input_dim)
self.linear_ff = nn.Linear(in_features=model_cfg.input_dim, out_features=model_cfg.hidden_dim, bias=True)
self.activation = model_cfg.activation
self.layer_norm_csgu = nn.LayerNorm(model_cfg.hidden_dim // 2)
self.depthwise_conv = nn.Conv1d(
in_channels=model_cfg.hidden_dim // 2,
out_channels=model_cfg.hidden_dim // 2,
kernel_size=model_cfg.kernel_size,
padding=(model_cfg.kernel_size - 1) // 2,
groups=model_cfg.hidden_dim // 2,
)
self.linear_out = nn.Linear(in_features=model_cfg.hidden_dim // 2, out_features=model_cfg.input_dim, bias=True)
self.dropout = model_cfg.dropout

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
:param x: shape [B, T, F], F=input_dim
:return: shape [B, T, F], F=input_dim
"""
x = self.layer_norm_input(x) # [B, T, F]
x = self.linear_ff(x) # [B, T, F']
x = self.activation(x)

# convolutional spatial gating unit (csgu)
x_1, x_2 = x.chunk(2, dim=-1) # [B, T, F'//2], [B, T, F'//2]
x_2 = self.layer_norm_csgu(x_2)
# conv layers expect shape [B, F, T] so we have to transpose here
x_2 = x_2.transpose(1, 2) # [B, F'//2, T]
x_2 = self.depthwise_conv(x_2) # [B, F'//2, T]
x_2 = x_2.transpose(1, 2) # [B, T, F'//2]
x = x_1 * x_2 # [B, T, F'//2]
x = nn.functional.dropout(x, p=self.dropout, training=self.training)

x = self.linear_out(x) # [B, T, F]
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
return x
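
For quick reference (not part of the PR), a minimal standalone sketch of exercising the cgMLP above; the values are illustrative, with hidden_dim set to 6*input_dim as suggested by the paper, an even hidden_dim and an odd kernel_size as required by check_valid.

import torch
from torch import nn

from i6_models.parts.e_branchformer.cgmlp import ConvolutionalGatingMLPV1, ConvolutionalGatingMLPV1Config

cfg = ConvolutionalGatingMLPV1Config(
    input_dim=64, hidden_dim=6 * 64, kernel_size=31, dropout=0.1, activation=nn.functional.gelu
)
cgmlp = ConvolutionalGatingMLPV1(cfg)

x = torch.randn(2, 50, 64)  # [B, T, F]
y = cgmlp(x)  # projected to hidden_dim, gated (halving the hidden dim), projected back to input_dim
assert y.shape == x.shape  # [B, T, F]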
62 changes: 62 additions & 0 deletions i6_models/parts/e_branchformer/merge.py
@@ -0,0 +1,62 @@
from __future__ import annotations

__all__ = ["MergerV1Config", "MergerV1"]

from dataclasses import dataclass

import torch
from torch import nn

from i6_models.config import ModelConfiguration


@dataclass
class MergerV1Config(ModelConfiguration):
"""
Attributes:
input_dim: input dimension
kernel_size: kernel size of the depthwise convolution layer
dropout: dropout probability
"""

input_dim: int
kernel_size: int
dropout: float

def check_valid(self):
assert self.kernel_size % 2 == 1, "MergerV1 only supports odd kernel sizes"

def __post_init__(self):
super().__post_init__()
self.check_valid()


class MergerV1(nn.Module):
def __init__(self, model_cfg: MergerV1Config):
"""
The merge module to combine the outputs of the local extractor (cgMLP) and the global extractor (MHSA).
Here we take the best variant from the E-Branchformer paper (Fig. 3c); refer to
https://arxiv.org/abs/2210.00077 for more merge module variants.
"""
super().__init__()

self.depthwise_conv = nn.Conv1d(
in_channels=model_cfg.input_dim * 2,
out_channels=model_cfg.input_dim * 2,
kernel_size=model_cfg.kernel_size,
padding=(model_cfg.kernel_size - 1) // 2,
groups=model_cfg.input_dim * 2,
)
self.linear_ff = nn.Linear(in_features=2 * model_cfg.input_dim, out_features=model_cfg.input_dim, bias=True)
self.dropout = model_cfg.dropout

def forward(self, x_1: torch.Tensor, x_2: torch.Tensor) -> torch.Tensor:
x_concat = torch.cat([x_1, x_2], dim=-1) # [B, T, 2F]
# conv layers expect shape [B, F, T] so we have to transpose here
x = x_concat.transpose(1, 2) # [B, 2F, T]
x = self.depthwise_conv(x)
x = x.transpose(1, 2) # [B, T, 2F]
x = x + x_concat
x = self.linear_ff(x)
x = nn.functional.dropout(x, p=self.dropout, training=self.training)
return x
36 changes: 36 additions & 0 deletions tests/test_e_branchformer.py
@@ -0,0 +1,36 @@
from itertools import product

import torch
from torch import nn

from i6_models.parts.e_branchformer.cgmlp import ConvolutionalGatingMLPV1Config, ConvolutionalGatingMLPV1
from i6_models.parts.e_branchformer.merge import MergerV1Config, MergerV1


def test_ConvolutionalGatingMLPV1():
def get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation):
input_dim = input_shape[-1]
cfg = ConvolutionalGatingMLPV1Config(input_dim, hidden_dim, kernel_size, dropout, activation)
e_branchformer_cgmlp_part = ConvolutionalGatingMLPV1(cfg)
x = torch.randn(input_shape)
y = e_branchformer_cgmlp_part(x)
return y.shape

for input_shape, hidden_dim, kernel_size, dropout, activation in product(
[(100, 5, 20), (200, 30, 10)], [120, 60], [9, 15], [0.1, 0.3], [nn.functional.gelu, nn.functional.relu]
):
assert get_output_shape(input_shape, hidden_dim, kernel_size, dropout, activation) == input_shape


def test_MergerV1():
def get_output_shape(input_shape, kernel_size, dropout):
input_dim = input_shape[-1]
cfg = MergerV1Config(input_dim, kernel_size, dropout)
e_branchformer_merge_part = MergerV1(cfg)
tensor_local = torch.randn(input_shape)
tensor_global = torch.randn(input_shape)
y = e_branchformer_merge_part(tensor_local, tensor_global)
return y.shape

for input_shape, kernel_size, dropout in product([(100, 5, 20), (200, 30, 10)], [15, 31], [0.1, 0.3]):
assert get_output_shape(input_shape, kernel_size, dropout) == input_shape