Skip to content

Commit

Permalink
use model cfg, make act configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
mmz33 committed May 17, 2023
1 parent de07f3a commit 2146f81
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
25 changes: 20 additions & 5 deletions i6_models/parts/conformer/convolution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
from __future__ import annotations

from dataclasses import dataclass

import torch
from torch import nn
from i6_models.config import ModelConfiguration
from typing import Callable, Union, Any, Type


@dataclass
class ConformerConvolutionV1Config(ModelConfiguration):
    """
    Configuration for :class:`ConformerConvolutionV1`.

    :param channels: number of channels for the conv layers
    :param kernel_size: kernel size of the conv layers
    :param dropout: dropout probability
    :param activation: activation function, given either as a module or as a
        callable mapping a tensor to a tensor
    """

    channels: int
    kernel_size: int
    dropout: float
    activation: Union[nn.Module, Callable[[torch.Tensor], torch.Tensor]]


class ConformerConvolutionV1(nn.Module):
Expand All @@ -9,15 +22,17 @@ class ConformerConvolutionV1(nn.Module):
see also: https://github.com/espnet/espnet/blob/713e784c0815ebba2053131307db5f00af5159ea/espnet/nets/pytorch_backend/conformer/convolution.py#L13
"""

def __init__(self, channels: int, kernel_size: int, dropout: float = 0.1, activation: nn.Module = nn.SiLU()):
def __init__(self, model_cfg: ConformerConvolutionV1Config):
"""
:param channels: number of channels for conv layers
:param kernel_size: kernel size of conv layers
:param dropout: dropout probability
:param activation: activation function applied after batch norm
:param model_cfg: model configuration for this module
"""
super().__init__()

channels = model_cfg.channels
kernel_size = model_cfg.kernel_size
dropout = model_cfg.dropout
activation = model_cfg.activation

self.pointwise_conv1 = nn.Linear(in_features=channels, out_features=2 * channels)
self.depthwise_conv = nn.Conv1d(
in_channels=channels,
Expand Down
13 changes: 9 additions & 4 deletions tests/test_conformer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from i6_models.parts.conformer.convolution import ConformerConvolutionV1
from i6_models.parts.conformer.convolution import ConformerConvolutionV1, ConformerConvolutionV1Config
import torch
import torch.nn as nn


def test_conformer_convolution_output_shape():
def get_output_shape(batch, time, features, kernel_size=31, dropout=0.1):
def get_output_shape(batch, time, features, kernel_size=31, dropout=0.1, activation=nn.functional.silu):
x = torch.randn(batch, time, features)
conformer_conv_part = ConformerConvolutionV1(channels=features, kernel_size=kernel_size, dropout=dropout)
cfg = ConformerConvolutionV1Config(
channels=features, kernel_size=kernel_size, dropout=dropout, activation=activation
)
conformer_conv_part = ConformerConvolutionV1(cfg)
y = conformer_conv_part(x)
return y.shape

assert get_output_shape(1, 50, 100) == (1, 50, 100) # test with batch size 1
assert get_output_shape(10, 50, 250) == (10, 50, 250)
assert get_output_shape(10, 50, 250, activation=nn.functional.relu) == (10, 50, 250) # different activation
assert get_output_shape(1, 50, 100) == (1, 50, 100) # test with batch size 1
assert get_output_shape(10, 1, 50) == (10, 1, 50) # time dim 1
assert get_output_shape(10, 10, 20, dropout=0.0) == (10, 10, 20) # dropout 0
assert get_output_shape(10, 10, 20, kernel_size=3) == (10, 10, 20) # odd kernel size
Expand Down

0 comments on commit 2146f81

Please sign in to comment.