From bb8fa4e690117bae6ab7694b908ee3366376c54b Mon Sep 17 00:00:00 2001 From: Nick Rossenbach Date: Thu, 24 Oct 2024 16:29:56 +0200 Subject: [PATCH] Disable autocast for ZoneoutLSTM cell --- i6_models/decoder/zoneout_lstm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/i6_models/decoder/zoneout_lstm.py b/i6_models/decoder/zoneout_lstm.py index fef76390..d23ba8d0 100644 --- a/i6_models/decoder/zoneout_lstm.py +++ b/i6_models/decoder/zoneout_lstm.py @@ -24,7 +24,8 @@ def __init__(self, cell: nn.RNNCellBase, zoneout_h: float, zoneout_c: float): def forward( self, inputs: torch.Tensor, state: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: - h, c = self.cell(inputs) + with torch.autocast(device_type="cuda", enabled=False): + h, c = self.cell(inputs) prev_h, prev_c = state h = self._zoneout(prev_h, h, self.zoneout_h) c = self._zoneout(prev_c, c, self.zoneout_c)