From 6bf9e2e8345732ba0b12d9ef2bb633d598ea6917 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20M=2E=20L=C3=BCscher?= Date: Thu, 19 Dec 2024 12:30:34 +0100 Subject: [PATCH] add import, set var correctly, add doc --- i6_models/parts/lstm.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/i6_models/parts/lstm.py b/i6_models/parts/lstm.py index 898c075..b1109fe 100644 --- a/i6_models/parts/lstm.py +++ b/i6_models/parts/lstm.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import torch from torch import nn -from typing import Dict, Union +from typing import Dict, Tuple, Union from i6_models.config import ModelConfiguration @@ -26,12 +26,18 @@ def from_dict(cls, model_cfg_dict: Dict): class LstmBlockV1(nn.Module): def __init__(self, model_cfg: Union[LstmBlockV1Config, Dict], **kwargs): + """ + Model definition of LSTM block. Contains single lstm stack and padding sequence in forward call. + + :param model_cfg: holds model configuration as dataclass or dict instance. + :param kwargs: + """ super().__init__() self.cfg = LstmBlockV1Config.from_dict(model_cfg) if isinstance(model_cfg, Dict) else model_cfg self.dropout = self.cfg.dropout - self.enforce_sorted = None + self.enforce_sorted = self.cfg.enforce_sorted self.lstm_stack = nn.LSTM( input_size=self.cfg.input_dim, hidden_size=self.cfg.hidden_dim,