diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 4c474c4a..d0a75244 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -103,7 +103,7 @@ def __init__( self.l2_alpha = l2_alpha self.ms_dssim_alpha = ms_dssim_alpha - @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + @torch.amp.custom_fwd(cast_inputs=torch.float32) def forward(self, preds, target): loss = 0 if self.l1_alpha: