Skip to content

Commit

Permalink
⚡️Optimized initializations of RNN
Browse files Browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Mar 8, 2021
1 parent 1019114 commit 2193706
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions cflearn/modules/extractors/rnn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any
from typing import Dict
from torch.nn import init

from ..base import ExtractorBase
from ...transform.core import Dimensions
Expand All @@ -26,9 +27,17 @@ def __init__(
input_dimensions = [self.in_dim]
self.hidden_size = cell_config["hidden_size"]
input_dimensions += [self.hidden_size] * (num_layers - 1)
self.rnn_list = torch.nn.ModuleList(
[rnn_base(dim, **cell_config) for dim in input_dimensions]
)
rnn_list = []
for dim in input_dimensions:
rnn = rnn_base(dim, **cell_config)
with torch.no_grad():
for name, param in rnn.named_parameters():
if "weight" in name:
init.orthogonal_(param)
elif "bias" in name:
init.zeros_(param)
rnn_list.append(rnn)
self.rnn_list = torch.nn.ModuleList(rnn_list)

@property
def flatten_ts(self) -> bool:
Expand Down

0 comments on commit 2193706

Please sign in to comment.