diff --git a/i6_models/assemblies/conformer/__init__.py b/i6_models/assemblies/conformer/__init__.py
index a5e0003b..3ac14e46 100644
--- a/i6_models/assemblies/conformer/__init__.py
+++ b/i6_models/assemblies/conformer/__init__.py
@@ -1,2 +1,3 @@
 from .conformer_v1 import *
 from .conformer_v2 import *
+from .conformer_rel_pos_v1 import *
diff --git a/i6_models/assemblies/conformer/conformer_rel_pos_v1.py b/i6_models/assemblies/conformer/conformer_rel_pos_v1.py
new file mode 100644
index 00000000..3f6c0524
--- /dev/null
+++ b/i6_models/assemblies/conformer/conformer_rel_pos_v1.py
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+__all__ = [
+    "ConformerRelPosBlockV1Config",
+    "ConformerRelPosEncoderV1Config",
+    "ConformerRelPosBlockV1",
+    "ConformerRelPosEncoderV1",
+]
+
+import torch
+from torch import nn
+from dataclasses import dataclass, field
+from typing import List
+
+from i6_models.config import ModelConfiguration, ModuleFactoryV1
+from i6_models.parts.conformer import (
+    ConformerConvolutionV2,
+    ConformerConvolutionV2Config,
+    ConformerMHSARelPosV1,
+    ConformerMHSARelPosV1Config,
+    ConformerPositionwiseFeedForwardV2,
+    ConformerPositionwiseFeedForwardV2Config,
+)
+from i6_models.assemblies.conformer import ConformerEncoderV2
+
+
+@dataclass
+class ConformerRelPosBlockV1Config(ModelConfiguration):
+    """
+    Attributes:
+        ff_cfg: Configuration for ConformerPositionwiseFeedForwardV2
+        mhsa_cfg: Configuration for ConformerMHSARelPosV1
+        conv_cfg: Configuration for ConformerConvolutionV2
+        modules: List of modules to use for ConformerRelPosBlockV1,
+            "ff" for feed forward module, "mhsa" for multi-head self attention module, "conv" for conv module
+        scales: List of scales to apply to the module outputs before the residual connection
+    """
+
+    # nested configurations
+    ff_cfg: ConformerPositionwiseFeedForwardV2Config
+    mhsa_cfg: ConformerMHSARelPosV1Config
+    conv_cfg: ConformerConvolutionV2Config
+    modules: List[str] = field(default_factory=lambda: ["ff", "mhsa", "conv", "ff"])
+    scales: List[float] = field(default_factory=lambda: [0.5, 1.0, 1.0, 0.5])
+
+    def __post_init__(self):
+        super().__post_init__()
+        assert len(self.modules) == len(self.scales), "modules and scales must have same length"
+        for module_name in self.modules:
+            assert module_name in ["ff", "mhsa", "conv"], "module not supported"
+
+
+class ConformerRelPosBlockV1(nn.Module):
+    """
+    Conformer block module, modifications compared to ConformerBlockV1:
+    - uses ConformerMHSARelPosV1 as MHSA module
+    - enables constructing the block with a self-defined module_list, as in ConformerBlockV2
+    """
+
+    def __init__(self, cfg: ConformerRelPosBlockV1Config):
+        """
+        :param cfg: conformer block configuration with subunits for the different conformer parts
+        """
+        super().__init__()
+
+        modules = []
+        for module_name in cfg.modules:
+            if module_name == "ff":
+                modules.append(ConformerPositionwiseFeedForwardV2(cfg=cfg.ff_cfg))
+            elif module_name == "mhsa":
+                modules.append(ConformerMHSARelPosV1(cfg=cfg.mhsa_cfg))
+            elif module_name == "conv":
+                modules.append(ConformerConvolutionV2(model_cfg=cfg.conv_cfg))
+            else:
+                raise NotImplementedError
+
+        self.module_list = nn.ModuleList(modules)
+        self.scales = cfg.scales
+        self.final_layer_norm = torch.nn.LayerNorm(cfg.ff_cfg.input_dim)
+
+    def forward(self, x: torch.Tensor, /, sequence_mask: torch.Tensor) -> torch.Tensor:
+        """
+        :param x: input tensor of shape [B, T, F]
+        :param sequence_mask: mask tensor where 1 defines positions within the sequence and 0 outside, shape: [B, T]
+        :return: torch.Tensor of shape [B, T, F]
+        """
+        for scale, module in zip(self.scales, self.module_list):
+            if isinstance(module, ConformerMHSARelPosV1):
+                x = scale * module(x, sequence_mask) + x
+            else:
+                x = scale * module(x) + x
+        x = self.final_layer_norm(x)  # [B, T, F]
+        return x
+
+
+@dataclass
+class ConformerRelPosEncoderV1Config(ModelConfiguration):
+    """
+    Attributes:
+        num_layers: Number of conformer layers in the conformer encoder
+        frontend: A pair of ConformerFrontend and corresponding config
+        block_cfg: Configuration for ConformerRelPosBlockV1
+    """
+
+    num_layers: int
+
+    # nested configurations
+    frontend: ModuleFactoryV1
+    block_cfg: ConformerRelPosBlockV1Config
+
+
+class ConformerRelPosEncoderV1(ConformerEncoderV2):
+    """
+    Modifications compared to ConformerEncoderV2:
+    - supports Shaw's relative positional encoding using learnable position embeddings
+      and Transformer-XL style relative PE using fixed sinusoidal or learnable position embeddings
+    """
+
+    def __init__(self, cfg: ConformerRelPosEncoderV1Config):
+        """
+        :param cfg: conformer encoder configuration with subunits for frontend and conformer blocks
+        """
+        super().__init__(cfg)
+
+        self.frontend = cfg.frontend()
+        self.module_list = torch.nn.ModuleList([ConformerRelPosBlockV1(cfg.block_cfg) for _ in range(cfg.num_layers)])
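For orientation, a minimal usage sketch of the new block with its nested V2/RelPos configs; this is not part of the diff, and the dimensions, dropout values and the BatchNorm choice below are illustrative assumptions only:

# Minimal usage sketch for ConformerRelPosBlockV1 (illustrative values, not from the diff).
import torch
from torch import nn

from i6_models.assemblies.conformer import ConformerRelPosBlockV1, ConformerRelPosBlockV1Config
from i6_models.parts.conformer import (
    ConformerConvolutionV2Config,
    ConformerMHSARelPosV1Config,
    ConformerPositionwiseFeedForwardV2Config,
)

block_cfg = ConformerRelPosBlockV1Config(
    ff_cfg=ConformerPositionwiseFeedForwardV2Config(
        input_dim=64, hidden_dim=256, dropout=0.1, activation=nn.functional.silu, dropout_broadcast_axes=None
    ),
    mhsa_cfg=ConformerMHSARelPosV1Config(
        input_dim=64,
        num_att_heads=8,
        with_bias=False,
        att_weights_dropout=0.1,
        learnable_pos_emb=True,  # Shaw-style learnable relative position embeddings
        rel_pos_clip=16,
        with_linear_pos=False,
        with_pos_bias=False,
        separate_pos_emb_per_head=False,
        pos_emb_dropout=0.0,
        dropout=0.1,
        dropout_broadcast_axes=None,
    ),
    conv_cfg=ConformerConvolutionV2Config(
        channels=64,
        kernel_size=31,
        dropout=0.1,
        activation=nn.functional.silu,
        norm=nn.BatchNorm1d(64),
        dropout_broadcast_axes=None,
    ),
)
block = ConformerRelPosBlockV1(block_cfg)

x = torch.randn(2, 20, 64)  # [B, T, F]
sequence_mask = torch.ones(2, 20, dtype=torch.bool)  # all positions are valid
y = block(x, sequence_mask=sequence_mask)  # [2, 20, 64]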
"T" for broadcasting to the time axis + setting to None to disable broadcasting + """ + + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] + + def check_valid(self): + assert self.kernel_size % 2 == 1, "ConformerConvolutionV1 only supports odd kernel sizes" + + assert self.dropout_broadcast_axes in [ + None, + "B", + "T", + "BT", + ], "invalid value, supported are None, 'B', 'T' and 'BT'" + + +class ConformerConvolutionV2(ConformerConvolutionV1): + """ + Augments ConformerMHSAV1 with dropout broadcasting + """ + + def __init__(self, model_cfg: ConformerConvolutionV2Config): + """ + :param model_cfg: model configuration for this module + """ + super().__init__(model_cfg) + + self.dropout = BroadcastDropout(model_cfg.dropout, dropout_broadcast_axes=model_cfg.dropout_broadcast_axes) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + :param tensor: input tensor of shape [B,T,F] + :return: torch.Tensor of shape [B,T,F] + """ + tensor = self.layer_norm(tensor) + tensor = self.pointwise_conv1(tensor) # [B,T,2F] + tensor = nn.functional.glu(tensor, dim=-1) # [B,T,F] + + # conv layers expect shape [B,F,T] so we have to transpose here + tensor = tensor.transpose(1, 2) # [B,F,T] + tensor = self.depthwise_conv(tensor) + + tensor = self.norm(tensor) + tensor = tensor.transpose(1, 2) # transpose back to [B,T,F] + + tensor = self.activation(tensor) + tensor = self.pointwise_conv2(tensor) + + tensor = self.dropout(tensor) + + return tensor diff --git a/i6_models/parts/conformer/feedforward.py b/i6_models/parts/conformer/feedforward.py index 323988a3..9d7dda4b 100644 --- a/i6_models/parts/conformer/feedforward.py +++ b/i6_models/parts/conformer/feedforward.py @@ -1,14 +1,20 @@ from __future__ import annotations -__all__ = ["ConformerPositionwiseFeedForwardV1", "ConformerPositionwiseFeedForwardV1Config"] +__all__ = [ + "ConformerPositionwiseFeedForwardV1", + "ConformerPositionwiseFeedForwardV1Config", + "ConformerPositionwiseFeedForwardV2", + "ConformerPositionwiseFeedForwardV2Config", +] from dataclasses import dataclass -from typing import Callable +from typing import Callable, Optional, Literal import torch from torch import nn from i6_models.config import ModelConfiguration +from i6_models.parts.dropout import BroadcastDropout @dataclass @@ -53,3 +59,57 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: tensor = self.linear_out(tensor) # [B,T,F] tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training) # [B,T,F] return tensor + + +@dataclass +class ConformerPositionwiseFeedForwardV2Config(ModelConfiguration): + """ + New attribute: + dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. 
"T" for broadcasting to the time axis + setting to None to disable broadcasting + Default value for `activation` removed + """ + + input_dim: int + hidden_dim: int + dropout: float + activation: Callable[[torch.Tensor], torch.Tensor] + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] + + def check_valid(self): + assert self.dropout_broadcast_axes in [ + None, + "B", + "T", + "BT", + ], "invalid value, supported are None, 'B', 'T' and 'BT'" + + def __post__init__(self): + super().__post_init__() + self.check_valid() + + +class ConformerPositionwiseFeedForwardV2(ConformerPositionwiseFeedForwardV1): + """ + Augments ConformerPositionwiseFeedForwardV1 with dropout broadcasting + """ + + def __init__(self, cfg: ConformerPositionwiseFeedForwardV2Config): + super().__init__(cfg) + + self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + """ + :param tensor: shape [B,T,F], F=input_dim + :return: shape [B,T,F], F=input_dim + """ + tensor = self.layer_norm(tensor) + tensor = self.linear_ff(tensor) # [B,T,F] + tensor = self.activation(tensor) # [B,T,F] + + tensor = self.dropout(tensor) # [B,T,F] + tensor = self.linear_out(tensor) # [B,T,F] + tensor = self.dropout(tensor) # [B,T,F] + + return tensor diff --git a/i6_models/parts/conformer/mhsa.py b/i6_models/parts/conformer/mhsa.py index c25a178d..2a67defe 100644 --- a/i6_models/parts/conformer/mhsa.py +++ b/i6_models/parts/conformer/mhsa.py @@ -1,11 +1,14 @@ from __future__ import annotations -__all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config"] +__all__ = ["ConformerMHSAV1", "ConformerMHSAV1Config", "ConformerMHSAV2", "ConformerMHSAV2Config"] + from dataclasses import dataclass +from typing import Optional, Literal import torch from i6_models.config import ModelConfiguration from i6_models.util import compat +from i6_models.parts.dropout import BroadcastDropout @dataclass @@ -60,3 +63,57 @@ def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> to output_tensor = torch.nn.functional.dropout(output_tensor, p=self.dropout, training=self.training) # [B,T,F] return output_tensor + + +@dataclass +class ConformerMHSAV2Config(ConformerMHSAV1Config): + """ + New attribute: + dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. 
"T" for broadcasting to the time axis + setting to None to disable broadcasting + """ + + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] + + def check_valid(self): + assert self.dropout_broadcast_axes in [ + None, + "B", + "T", + "BT", + ], "invalid value, supported are None, 'B', 'T' and 'BT'" + + def __post__init__(self): + super().__post_init__() + self.check_valid() + + +class ConformerMHSAV2(ConformerMHSAV1): + """ + Augments ConformerMHSAV1 with dropout broadcasting + """ + + def __init__(self, cfg: ConformerMHSAV2Config): + + super().__init__(cfg) + + self.dropout = BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes) + + def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor: + """ + Apply layer norm and multi-head self attention and dropout + + :param input_tensor: Input to the self attention of shape (B, T, F) + :param sequence_mask: Bool mask of shape (B, T), True signals within sequence, False outside, will be inverted to match the torch.nn.MultiheadAttention module + which will be applied/added to dot product, used to mask padded key positions out + """ + inv_sequence_mask = compat.logical_not(sequence_mask) + output_tensor = self.layernorm(input_tensor) # [B,T,F] + + output_tensor, _ = self.mhsa( + output_tensor, output_tensor, output_tensor, key_padding_mask=inv_sequence_mask, need_weights=False + ) # [B,T,F] + + output_tensor = self.dropout(output_tensor) + + return output_tensor # [B,T,F] diff --git a/i6_models/parts/conformer/mhsa_rel_pos.py b/i6_models/parts/conformer/mhsa_rel_pos.py new file mode 100644 index 00000000..12e174c1 --- /dev/null +++ b/i6_models/parts/conformer/mhsa_rel_pos.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +__all__ = ["ConformerMHSARelPosV1", "ConformerMHSARelPosV1Config"] + +from dataclasses import dataclass +import math +from typing import Optional, Literal + +import torch +from torch import nn +import torch.nn.functional as F + +from i6_models.config import ModelConfiguration +from i6_models.util import compat +from i6_models.parts.dropout import BroadcastDropout + + +@dataclass +class ConformerMHSARelPosV1Config(ModelConfiguration): + """ + Attributes: + input_dim: input dim and total dimension for query/key and value projections, should be divisible by `num_att_heads` + num_att_heads: number of attention heads + with_bias: whether to add bias to qkv and output linear projections + att_weights_dropout: attention weights dropout + learnable_pos_emb: whether to use learnable relative positional embeddings instead of fixed sinusoidal ones + rel_pos_clip: maximal relative postion for embedding + with_linear_pos: whether to linearly transform the positional embeddings + separate_pos_emb_per_head: whether to create head-dependent positional embeddings + with_pos_bias: whether to add additional position bias terms to the attention scores + pos_emb_dropout: dropout for the positional embeddings + dropout: multi-headed self attention output dropout + dropout_broadcast_axes: string of axes to which dropout is broadcast, e.g. 
"T" for broadcasting to the time axis + setting to None to disable broadcasting + """ + + input_dim: int + num_att_heads: int + with_bias: bool + att_weights_dropout: float + learnable_pos_emb: bool + rel_pos_clip: Optional[int] + with_linear_pos: bool + with_pos_bias: bool + separate_pos_emb_per_head: bool + pos_emb_dropout: float + dropout: float + dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] + + def __post_init__(self) -> None: + super().__post_init__() + assert self.input_dim % self.num_att_heads == 0, "input_dim must be divisible by num_att_heads" + assert self.dropout_broadcast_axes in [ + None, + "B", + "T", + "BT", + ], "invalid value, supported are None, 'B', 'T' and 'BT'" + + +class ConformerMHSARelPosV1(nn.Module): + """ + Conformer multi-headed self-attention module supporting + - self-attention with relative positional encoding proposed by Shaw et al. (cf. https://arxiv.org/abs/1803.02155) + * learnable_pos_emb = True + * with_pos_bias = False + * with_linear_pos = False + * separate_pos_emb_per_head = False (RETURNN default) + * with_bias = False (RETURNN default) + - and self-attention with Transformer-XL style relative PE by Dai et al. + (cf. https://arxiv.org/abs/1901.02860, https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py, + https://github.com/espnet/espnet/blob/master/espnet2/asr_transducer/encoder/modules/attention.py#L9) + * learnable_pos_emb = False + * with_pos_bias = True + * with_linear_pos = False (paper implementation) / with_linear_pos = True (ESPnet default) + * separate_pos_emb_per_head = False (paper implementation) / separate_pos_emb_per_head = True (ESPnet default) + * with_bias = False (paper implementation) / with_bias = True (ESPnet default) + """ + + def __init__(self, cfg: ConformerMHSARelPosV1Config): + + super().__init__() + + self.layernorm = nn.LayerNorm(cfg.input_dim) + + self.embed_dim = cfg.input_dim + self.num_heads = cfg.num_att_heads + self.embed_dim_per_head = self.embed_dim // self.num_heads + + self.learnable_pos_emb = cfg.learnable_pos_emb + self.rel_pos_clip = cfg.rel_pos_clip + self.separate_pos_emb_per_head = cfg.separate_pos_emb_per_head + self.with_pos_bias = cfg.with_pos_bias + self.pos_emb_dropout = nn.Dropout(cfg.pos_emb_dropout) + + assert not self.learnable_pos_emb or self.rel_pos_clip + + self.att_weights_dropout = nn.Dropout(cfg.att_weights_dropout) + + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + + # projection matrices + self.qkv_proj = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=cfg.with_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=cfg.with_bias) + + self.register_parameter("rel_pos_embeddings", None) + self.register_parameter("pos_bias_u", None) + self.register_parameter("pos_bias_v", None) + + self.pos_emb_dim = ( + self.embed_dim if cfg.with_linear_pos or cfg.separate_pos_emb_per_head else self.embed_dim_per_head + ) + if self.learnable_pos_emb: + self.rel_pos_embeddings = nn.parameter.Parameter(torch.empty(self.rel_pos_clip * 2 + 1, self.pos_emb_dim)) + if cfg.with_linear_pos: + self.linear_pos = nn.Linear( + self.pos_emb_dim, + self.embed_dim if cfg.separate_pos_emb_per_head else self.embed_dim_per_head, + bias=False, + ) + else: + self.linear_pos = nn.Identity() + + if self.with_pos_bias: + self.pos_bias_u = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head)) + self.pos_bias_v = nn.parameter.Parameter(torch.empty(self.num_heads, self.embed_dim_per_head)) + + self.dropout 
= BroadcastDropout(cfg.dropout, dropout_broadcast_axes=cfg.dropout_broadcast_axes) + + self._reset_parameters() + + def _reset_parameters(self): + if self.learnable_pos_emb: + nn.init.xavier_normal_(self.rel_pos_embeddings) + if self.with_pos_bias: + # init taken from espnet default + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward(self, input_tensor: torch.Tensor, sequence_mask: torch.Tensor) -> torch.Tensor: + """ + Apply layer norm and multi-head self attention and dropout + + :param input_tensor: Input to the self attention of shape (B, T, F) + :param sequence_mask: bool mask of shape (B, T), True signals within sequence, False outside + """ + output_tensor = self.layernorm(input_tensor) # [B, T, F] + + time_dim_size = output_tensor.shape[1] + batch_dim_size = output_tensor.shape[0] + + # attention mask + # T: query seq. length, T' key/value seg length; T = T' if same input tensor + inv_sequence_mask = compat.logical_not(sequence_mask) # [B, T'] + mask = ( + torch.zeros_like(inv_sequence_mask, dtype=input_tensor.dtype) + .masked_fill(inv_sequence_mask, float("-inf")) + .reshape(batch_dim_size, 1, 1, time_dim_size) + ) # [B, 1, 1, T'] + + # query, key and value sequences + query_seq, key_seq, value_seq = self.qkv_proj(output_tensor).chunk(3, dim=-1) # [B, T, #heads * F'] + q = query_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T, #heads, F'] + k = key_seq.view(batch_dim_size, -1, self.num_heads, self.embed_dim_per_head) # [B, T', #heads, F'] + + if self.learnable_pos_emb: + pos_seq_q = torch.arange(time_dim_size, device=input_tensor.device) + pos_seq_k = torch.arange(time_dim_size, device=input_tensor.device) + + distance_mat = pos_seq_k[None, :] - pos_seq_q[:, None] + distance_mat_clipped = torch.clamp(distance_mat, -self.rel_pos_clip, self.rel_pos_clip) + + final_mat = distance_mat_clipped + self.rel_pos_clip + + rel_pos_embeddings = self.rel_pos_embeddings[final_mat] # [T, T', pos_emb_dim] + else: + rel_pos_embeddings = self._sinusoidal_pe( + torch.arange(time_dim_size - 1, -time_dim_size, -1, device=input_tensor.device, dtype=torch.float32), + self.pos_emb_dim, + ).view( + 1, 2 * time_dim_size - 1, self.pos_emb_dim + ) # [1, T+T'-1, pos_emb_dim] + + # dropout relative positional embeddings + rel_pos_embeddings = self.pos_emb_dropout( + rel_pos_embeddings + ) # [T, T', pos_emb_dim] or [1, T+T'-1, pos_emb_dim] + rel_pos_embeddings = rel_pos_embeddings.unsqueeze(2) # [T, T', 1, pos_emb_dim] or [1, T+T'-1, 1, pos_emb_dim] + + # linear transformation or identity + rel_pos_embeddings = self.linear_pos(rel_pos_embeddings) # [T, T', 1, F'|F] or [1, T+T'-1, 1, F'|F] + + if self.separate_pos_emb_per_head: + rel_pos_embeddings = rel_pos_embeddings.squeeze(2).reshape( + *rel_pos_embeddings.shape[:2], -1, self.embed_dim_per_head + ) # [T, T', #heads, F'] or [1, T+T'-1, #heads, F'] + + q_with_bias_u = q + self.pos_bias_u if self.with_pos_bias else q # [B, T, #heads, F'] + q_with_bias_v = q + self.pos_bias_v if self.with_pos_bias else q + + # attention matrix a and c + attn_ac = torch.einsum("bihf, bjhf -> bhij", q_with_bias_u, k) # [B, #heads, T, T'] + + # attention matrix b and d + attn_bd = torch.einsum( + "bihf, ijhf -> bhij", q_with_bias_v, rel_pos_embeddings + ) # [B, #heads, T, T'] or [B, #heads, T, T+T'+1] + + if not self.learnable_pos_emb: + attn_bd = self._rel_shift_bhij(attn_bd, k_len=time_dim_size) # [B, #heads, T, T'] + + attn = attn_ac + attn_bd + mask # [B, #heads, T, T'] + attn_scaled = attn * 
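The Transformer-XL branch relies on _rel_shift_bhij to turn scores over the 2*T-1 relative positions (ordered T-1 down to -(T-1), as produced by the sinusoidal branch above) into a [T, T'] matrix indexed by query and key position. A small sanity sketch, not part of the diff, assuming exactly that ordering of the last axis:

# Sketch: after the shift, entry (i, k) holds the score of relative position i - k.
import torch

from i6_models.parts.conformer.mhsa_rel_pos import ConformerMHSARelPosV1

T = 4
rel_positions = torch.arange(T - 1, -T, -1)  # [2T-1], same ordering as the sinusoidal branch
scores = rel_positions.float().view(1, 1, 1, -1).repeat(1, 1, T, 1)  # [B=1, H=1, T, 2T-1]

shifted = ConformerMHSARelPosV1._rel_shift_bhij(scores, k_len=T)  # [1, 1, T, T]
for i in range(T):
    for k in range(T):
        assert shifted[0, 0, i, k] == i - k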
diff --git a/i6_models/parts/dropout.py b/i6_models/parts/dropout.py
new file mode 100644
index 00000000..f2ea463f
--- /dev/null
+++ b/i6_models/parts/dropout.py
@@ -0,0 +1,57 @@
+from typing import Optional, Literal
+
+import torch
+from torch import nn
+
+
+class BroadcastDropout(nn.Module):
+    """
+    customized dropout module supporting dropout broadcasting
+    supported variants are:
+        - no broadcasting (default): dropout_broadcast_axes=None
+        - broadcast over the batch axis: dropout_broadcast_axes='B'
+        - broadcast over the time axis: dropout_broadcast_axes='T'
+        - broadcast over the batch and time axes: dropout_broadcast_axes='BT'
+    """
+
+    def __init__(self, p: float, dropout_broadcast_axes: Optional[Literal["B", "T", "BT"]] = None):
+        super().__init__()
+
+        self.p = p
+        assert dropout_broadcast_axes in [
+            None,
+            "B",
+            "T",
+            "BT",
+        ], "invalid value, supported are None, 'B', 'T' and 'BT'"
+        self.dropout_broadcast_axes = dropout_broadcast_axes
+
+    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
+        """
+        :param tensor: input tensor of shape [B, T, F]
+        :return: tensor of shape [B, T, F]
+        """
+        if self.dropout_broadcast_axes is None:
+            tensor = torch.nn.functional.dropout(tensor, p=self.p, training=self.training)
+        elif self.dropout_broadcast_axes == "T":  # [B, T, F] -> [B, F, T] -> [B, T, F]
+            # torch.nn.functional.dropout1d expects a 3D tensor and broadcasts in the last dimension.
+            tensor = torch.nn.functional.dropout1d(tensor.transpose(1, 2), p=self.p, training=self.training).transpose(
+                1, 2
+            )
+        elif self.dropout_broadcast_axes == "B":  # [B, T, F] -> [T, F, B] -> [B, T, F]
+            tensor = torch.nn.functional.dropout1d(tensor.permute(1, 2, 0), p=self.p, training=self.training).permute(
+                2, 0, 1
+            )
+        elif self.dropout_broadcast_axes == "BT":  # [B, T, F] -> [B*T, F] -> [F, B*T] -> [B*T, F] -> [B, T, F]
+            batch_dim_size = tensor.shape[0]
+            feature_dim_size = tensor.shape[-1]
+
+            tensor = (
+                torch.nn.functional.dropout1d(
+                    tensor.reshape(-1, feature_dim_size).transpose(0, 1), p=self.p, training=self.training
+                )
+                .transpose(0, 1)
+                .reshape(batch_dim_size, -1, feature_dim_size)
+            )
+
+        return tensor
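A quick check of what the broadcasting means in practice; this snippet is not part of the diff. With dropout_broadcast_axes="T", a given (batch, feature) channel is either dropped for all time steps or kept for all time steps:

import torch

from i6_models.parts.dropout import BroadcastDropout

drop = BroadcastDropout(p=0.5, dropout_broadcast_axes="T")
drop.train()  # dropout is only active in training mode

x = torch.ones(2, 5, 3)  # [B, T, F]
y = drop(x)

fully_dropped = (y == 0).all(dim=1)  # [B, F]: channel zeroed for every time step
fully_kept = (y != 0).all(dim=1)  # [B, F]: channel kept (and rescaled) for every time step
assert torch.all(fully_dropped | fully_kept)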
diff --git a/requirements_dev.txt b/requirements_dev.txt
index 60dd5826..8a3bc595 100644
--- a/requirements_dev.txt
+++ b/requirements_dev.txt
@@ -1,2 +1,3 @@
 onnx
-onnxruntime
\ No newline at end of file
+onnxruntime
+espnet @ git+https://github.com/espnet/espnet.git@9d607c1e7af91c2f611f7dd181ad091649f215c4
diff --git a/tests/test_conformer.py b/tests/test_conformer.py
index e6061673..29693c82 100644
--- a/tests/test_conformer.py
+++ b/tests/test_conformer.py
@@ -37,7 +37,9 @@ def get_output_shape(batch, time, features, norm=None, kernel_size=31, dropout=0
 def test_ConformerPositionwiseFeedForwardV1():
     def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation):
         x = torch.randn(input_shape)
-        cfg = ConformerPositionwiseFeedForwardV1Config(input_dim, hidden_dim, dropout, activation)
+        cfg = ConformerPositionwiseFeedForwardV1Config(
+            input_dim=input_dim, hidden_dim=hidden_dim, dropout=dropout, activation=activation
+        )
         conf_ffn_part = ConformerPositionwiseFeedForwardV1(cfg)
         y = conf_ffn_part(x)
         return y.shape
diff --git a/tests/test_conformer_rel_pos.py b/tests/test_conformer_rel_pos.py
new file mode 100644
index 00000000..3aafb32c
--- /dev/null
+++ b/tests/test_conformer_rel_pos.py
@@ -0,0 +1,206 @@
+from __future__ import annotations
+from itertools import product
+
+import torch
+from torch import nn
+
+from i6_models.parts.conformer.convolution import ConformerConvolutionV2, ConformerConvolutionV2Config
+from i6_models.parts.conformer.feedforward import (
+    ConformerPositionwiseFeedForwardV2,
+    ConformerPositionwiseFeedForwardV2Config,
+)
+from i6_models.parts.conformer.mhsa_rel_pos import ConformerMHSARelPosV1Config, ConformerMHSARelPosV1
+from i6_models.parts.conformer.norm import LayerNormNC
+
+
+def test_ConformerConvolutionV2():
+    def get_output_shape(
+        batch,
+        time,
+        features,
+        norm=None,
+        kernel_size=31,
+        dropout=0.1,
+        activation=nn.functional.silu,
+        dropout_broadcast_axes=None,
+    ):
+        x = torch.randn(batch, time, features)
+        if norm is None:
+            norm = nn.BatchNorm1d(features)
+        cfg = ConformerConvolutionV2Config(
+            channels=features,
+            kernel_size=kernel_size,
+            dropout=dropout,
+            activation=activation,
+            norm=norm,
+            dropout_broadcast_axes=dropout_broadcast_axes,
+        )
+        conformer_conv_part = ConformerConvolutionV2(cfg)
+        y = conformer_conv_part(x)
+        return y.shape
+
+    assert get_output_shape(10, 50, 250) == (10, 50, 250)
+    assert get_output_shape(10, 50, 250, activation=nn.functional.relu) == (10, 50, 250)  # different activation
+    assert get_output_shape(10, 50, 250, norm=LayerNormNC(250)) == (10, 50, 250)  # different norm
+    assert get_output_shape(1, 50, 100) == (1, 50, 100)  # test with batch size 1
+    assert get_output_shape(10, 1, 50) == (10, 1, 50)  # time dim 1
+    assert get_output_shape(10, 10, 20, dropout=0.0) == (10, 10, 20)  # dropout 0
+    assert get_output_shape(10, 10, 20, kernel_size=3) == (10, 10, 20)  # odd kernel size
+    assert get_output_shape(5, 480, 512, dropout_broadcast_axes="T") == (5, 480, 512)  # dropout broadcast to T
+    assert get_output_shape(5, 480, 512, dropout_broadcast_axes="BT") == (5, 480, 512)  # dropout broadcast to BT
+
+
+def test_ConformerPositionwiseFeedForwardV2():
+    def get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation, dropout_broadcast_axes=None):
+        x = torch.randn(input_shape)
+        cfg = ConformerPositionwiseFeedForwardV2Config(
+            input_dim=input_dim,
+            hidden_dim=hidden_dim,
+            dropout=dropout,
+            activation=activation,
+            dropout_broadcast_axes=dropout_broadcast_axes,
+        )
+        conf_ffn_part = ConformerPositionwiseFeedForwardV2(cfg)
+        y = conf_ffn_part(x)
+        return y.shape
+
+    for input_dim, hidden_dim, dropout, activation, dropout_broadcast_axes in product(
+        [10, 20], [100, 200], [0.1, 0.3], [nn.functional.silu, nn.functional.relu], [None, "B", "T", "BT"]
+    ):
+        input_shape = (10, 100, input_dim)
+        assert get_output_shape(input_shape, input_dim, hidden_dim, dropout, activation, dropout_broadcast_axes) == input_shape
+
+
+def test_ConformerMHSARelPosV1():
+    def get_output_shape(
+        input_shape,
+        seq_len,
+        input_dim,
+        with_bias=True,
+        num_att_heads=8,
+        att_weights_dropout=0.1,
+        dropout=0.1,
+        learnable_pos_emb=True,
+        with_linear_pos=False,
+        separate_pos_emb_per_head=False,
+        rel_pos_clip=16,
+        with_pos_bias=False,
+        pos_emb_dropout=0.0,
+        dropout_broadcast_axes=None,
+    ):
+        assert len(input_shape) == 3 and input_shape[-1] == input_dim
+
+        cfg = ConformerMHSARelPosV1Config(
+            input_dim=input_dim,
+            num_att_heads=num_att_heads,
+            with_bias=with_bias,
+            att_weights_dropout=att_weights_dropout,
+            dropout=dropout,
+            learnable_pos_emb=learnable_pos_emb,
+            with_linear_pos=with_linear_pos,
+            separate_pos_emb_per_head=separate_pos_emb_per_head,
+            rel_pos_clip=rel_pos_clip,
+            with_pos_bias=with_pos_bias,
+            pos_emb_dropout=pos_emb_dropout,
+            dropout_broadcast_axes=dropout_broadcast_axes,
+        )
+        conf_mhsa_rel_pos = ConformerMHSARelPosV1(cfg)
+        input_tensor = torch.randn(input_shape)
+        sequence_mask = torch.less(torch.arange(input_shape[1])[None, :], torch.tensor(seq_len)[:, None])
+
+        output = conf_mhsa_rel_pos(input_tensor, sequence_mask)
+
+        return list(output.shape)
+
+    # with key padding mask
+    input_shape = [4, 15, 32]  # B,T,F
+    seq_len = [15, 12, 10, 15]
+
+    for learnable_pos_emb, with_pos_bias, pos_emb_dropout, with_linear_pos, separate_pos_emb_per_head in product(
+        [True, False], [True, False], [0.0, 0.1], [True, False], [True, False]
+    ):
+        assert get_output_shape(
+            input_shape,
+            seq_len,
+            32,
+            learnable_pos_emb=learnable_pos_emb,
+            with_pos_bias=with_pos_bias,
+            pos_emb_dropout=pos_emb_dropout,
+            with_linear_pos=with_linear_pos,
+            separate_pos_emb_per_head=separate_pos_emb_per_head,
+        ) == [4, 15, 32]
+
+
+def test_ConformerMHSARelPosV1_against_Espnet():
+    from espnet2.asr_transducer.encoder.modules.attention import RelPositionMultiHeadedAttention
+    from espnet2.asr_transducer.encoder.modules.positional_encoding import RelPositionalEncoding
+
+    num_heads = 4
+    embed_size = 256
+    dropout_rate = 0.1
+    batch_dim_size = 4
+    time_dim_size = 50
+    seq_len = torch.Tensor([50, 10, 20, 40])
+    sequence_mask = torch.less(torch.arange(time_dim_size)[None, :], seq_len[:, None])
+
+    espnet_mhsa_module = RelPositionMultiHeadedAttention(
+        num_heads=num_heads, embed_size=embed_size, dropout_rate=dropout_rate
+    )
+    espnet_mhsa_module.eval()
+    espnet_pos_enc_module = RelPositionalEncoding(embed_size, dropout_rate=dropout_rate)
+    espnet_pos_enc_module.eval()
+
+    cfg = ConformerMHSARelPosV1Config(
+        input_dim=embed_size,
+        num_att_heads=num_heads,
+        with_bias=True,
+        att_weights_dropout=dropout_rate,
+        dropout=dropout_rate,
+        learnable_pos_emb=False,
+        with_linear_pos=True,
+        separate_pos_emb_per_head=True,
+        rel_pos_clip=None,
+        with_pos_bias=True,
+        pos_emb_dropout=dropout_rate,
+        dropout_broadcast_axes=None,
+    )
+    own_mhsa_module = ConformerMHSARelPosV1(cfg)
+    own_mhsa_module.eval()
+    own_mhsa_module.linear_pos = espnet_mhsa_module.linear_pos
+    own_mhsa_module.pos_bias_u = espnet_mhsa_module.pos_bias_u
+    own_mhsa_module.pos_bias_v = espnet_mhsa_module.pos_bias_v
+    own_mhsa_module.out_proj = espnet_mhsa_module.linear_out
+    own_mhsa_module.qkv_proj.weight = nn.Parameter(
+        torch.cat(
+            [
+                espnet_mhsa_module.linear_q.weight,
+                espnet_mhsa_module.linear_k.weight,
+                espnet_mhsa_module.linear_v.weight,
+            ],
+            dim=0,
+        )
+    )
+    own_mhsa_module.qkv_proj.bias = nn.Parameter(
+        torch.cat(
+            [espnet_mhsa_module.linear_q.bias, espnet_mhsa_module.linear_k.bias, espnet_mhsa_module.linear_v.bias],
+            dim=0,
+        )
+    )
+
+    input_tensor = torch.rand((batch_dim_size, time_dim_size, embed_size))
+    inv_sequence_mask = torch.logical_not(sequence_mask)
+
+    input_tensor_layernorm = own_mhsa_module.layernorm(input_tensor)
+
+    espnet_pos_enc = espnet_pos_enc_module(input_tensor_layernorm)
+    espnet_output_tensor = espnet_mhsa_module(
+        query=input_tensor_layernorm,
+        key=input_tensor_layernorm,
+        value=input_tensor_layernorm,
+        pos_enc=espnet_pos_enc,
+        mask=inv_sequence_mask,
+    )
+
+    own_output_tensor = own_mhsa_module(input_tensor, sequence_mask=sequence_mask)
+
+    assert torch.allclose(espnet_output_tensor, own_output_tensor, rtol=1e-03, atol=1e-6)
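For reference, the two attention flavours listed in the ConformerMHSARelPosV1 docstring written out as configs; the second combination is the ESPnet-style setting that the parity test above compares against (the numeric values are arbitrary examples, not taken from the diff):

from i6_models.parts.conformer import ConformerMHSARelPosV1, ConformerMHSARelPosV1Config

# Shaw et al.: learnable, clipped relative position embeddings (RETURNN-style defaults)
shaw_cfg = ConformerMHSARelPosV1Config(
    input_dim=256,
    num_att_heads=4,
    with_bias=False,
    att_weights_dropout=0.1,
    learnable_pos_emb=True,
    rel_pos_clip=16,
    with_linear_pos=False,
    with_pos_bias=False,
    separate_pos_emb_per_head=False,
    pos_emb_dropout=0.0,
    dropout=0.1,
    dropout_broadcast_axes=None,
)

# Transformer-XL style: fixed sinusoidal embeddings plus position bias terms (ESPnet defaults)
xl_cfg = ConformerMHSARelPosV1Config(
    input_dim=256,
    num_att_heads=4,
    with_bias=True,
    att_weights_dropout=0.1,
    learnable_pos_emb=False,
    rel_pos_clip=None,
    with_linear_pos=True,
    with_pos_bias=True,
    separate_pos_emb_per_head=True,
    pos_emb_dropout=0.1,
    dropout=0.1,
    dropout_broadcast_axes=None,
)

mhsa = ConformerMHSARelPosV1(xl_cfg)  # or ConformerMHSARelPosV1(shaw_cfg)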