Skip to content

Commit

Permalink
Disable autocast for ZoneoutLSTM cell
Browse files Browse the repository at this point in the history
  • Loading branch information
JackTemaki committed Oct 24, 2024
1 parent a988e85 commit bb8fa4e
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion i6_models/decoder/zoneout_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def __init__(self, cell: nn.RNNCellBase, zoneout_h: float, zoneout_c: float):
def forward(
self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
h, c = self.cell(inputs)
with torch.autocast(device_type="cuda", enabled=False):
h, c = self.cell(inputs)
prev_h, prev_c = state
h = self._zoneout(prev_h, h, self.zoneout_h)
c = self._zoneout(prev_c, c, self.zoneout_c)
Expand Down

0 comments on commit bb8fa4e

Please sign in to comment.