Add part: lstm block #66
base: main
Changes from 2 commits
@@ -0,0 +1,64 @@
__all__ = ["LstmBlockV1Config", "LstmBlockV1"]

from dataclasses import dataclass
import torch
from torch import nn
from typing import Dict, Tuple, Union

from i6_models.config import ModelConfiguration


@dataclass
class LstmBlockV1Config(ModelConfiguration):
    input_dim: int  # input feature dimension
    hidden_dim: int  # hidden state dimension of each LSTM layer
    num_layers: int  # number of stacked LSTM layers
    bias: bool  # whether the LSTM layers use bias parameters
    dropout: float  # dropout between LSTM layers (nn.LSTM semantics)
    bidirectional: bool  # whether to run the LSTM bidirectionally
    enforce_sorted: bool  # whether sequences are already sorted by decreasing length

    @classmethod
    def from_dict(cls, model_cfg_dict: Dict):
        model_cfg_dict = model_cfg_dict.copy()
        return cls(**model_cfg_dict)

class LstmBlockV1(nn.Module):
    def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs):
        super().__init__()

        self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, dict) else model_cfg

        self.dropout = self.cfg.dropout
        self.enforce_sorted = self.cfg.enforce_sorted
        self.lstm_stack = nn.LSTM(
            input_size=self.cfg.input_dim,
            hidden_size=self.cfg.hidden_dim,
            num_layers=self.cfg.num_layers,
            bias=self.cfg.bias,
            dropout=self.dropout,
            batch_first=True,
            bidirectional=self.cfg.bidirectional,
        )
    def forward(self, x: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Why only when not scripting? Don't you want that …

I followed the example in the blstm part. I did not copy the comment over... I did not yet get to look why this is necessary.
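            # Presumably mirrors the blstm part: pack_padded_sequence expects the
            # lengths tensor on the CPU in eager mode, hence the device move below;
            # under scripting/tracing the explicit move is skipped so it is not
            # baked into the exported graph.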
            if seq_len.get_device() >= 0:
                seq_len = seq_len.cpu()

        lstm_packed_in = nn.utils.rnn.pack_padded_sequence(
            input=x,
            lengths=seq_len,
            enforce_sorted=self.enforce_sorted,
            batch_first=True,
        )

        lstm_out, _ = self.lstm_stack(lstm_packed_in)
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            lstm_out,
            padding_value=0.0,
            batch_first=True,
        )

        return lstm_out, seq_len
Same Q as in the other PR: why is this necessary now, and hasn't been for the other assemblies?
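For context, a minimal usage sketch of the block above, assuming hypothetical dimensions and a padded, batch-first input; the dict keys simply mirror the config fields in the diff:

import torch

# Hypothetical configuration values; any consistent dimensions work.
cfg = LstmBlockV1Config.from_dict(
    {
        "input_dim": 80,
        "hidden_dim": 512,
        "num_layers": 2,
        "bias": True,
        "dropout": 0.1,
        "bidirectional": True,
        "enforce_sorted": False,
    }
)
lstm_block = LstmBlockV1(cfg)  # passing the dict directly would also work

# Padded batch of shape [B, T, F] with per-sequence lengths.
x = torch.randn(3, 50, 80)
seq_len = torch.tensor([50, 42, 17])

out, out_len = lstm_block(x, seq_len)
print(out.shape)  # [3, 50, 1024]: with bidirectional=True the feature dim is 2 * hidden_dim

The returned seq_len is passed through unchanged, and the padded positions in out are zero-filled by pad_packed_sequence.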