diff --git a/cflearn/modules/extractors/rnn/core.py b/cflearn/modules/extractors/rnn/core.py index 6ae744c4c..e4bb638bc 100644 --- a/cflearn/modules/extractors/rnn/core.py +++ b/cflearn/modules/extractors/rnn/core.py @@ -2,6 +2,7 @@ from typing import Any from typing import Dict +from torch.nn import init from ..base import ExtractorBase from ...transform.core import Dimensions @@ -26,9 +27,17 @@ def __init__( input_dimensions = [self.in_dim] self.hidden_size = cell_config["hidden_size"] input_dimensions += [self.hidden_size] * (num_layers - 1) - self.rnn_list = torch.nn.ModuleList( - [rnn_base(dim, **cell_config) for dim in input_dimensions] - ) + rnn_list = [] + for dim in input_dimensions: + rnn = rnn_base(dim, **cell_config) + with torch.no_grad(): + for name, param in rnn.named_parameters(): + if "weight" in name: + init.orthogonal_(param) + elif "bias" in name: + init.zeros_(param) + rnn_list.append(rnn) + self.rnn_list = torch.nn.ModuleList(rnn_list) @property def flatten_ts(self) -> bool: