From 33d65b2bb7db5ea8847fcf2212be69b487fdb984 Mon Sep 17 00:00:00 2001
From: younesbelakda
Date: Wed, 22 Feb 2023 15:27:16 +0000
Subject: [PATCH] fix autocast issue

---
 src/peft/tuners/lora.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index 82b8b6432b..9e830d3cfb 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -494,8 +494,18 @@ def reset_parameters(self):
 
     def forward(self, x: torch.Tensor):
         result = super().forward(x)
+
         if self.disable_adapters:
             return result
         elif self.r > 0:
-            result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+            if not torch.is_autocast_enabled():
+                expected_dtype = result.dtype
+
+                if x.dtype != torch.float32:
+                    x = x.float()
+                output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
+                result += output
+            else:
+                output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+                result += output
         return result
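
For context, a minimal standalone sketch of the failure mode this patch guards against. The modules below (base, lora_A, lora_B, scaling, and the dropout-free LoRA path) are hypothetical stand-ins, not the actual peft layer: when the frozen base layer runs in half precision but the LoRA matrices are kept in float32, feeding the half-precision activations straight into them outside of autocast raises a dtype mismatch, so the patch upcasts the input and casts the LoRA delta back to the base output's dtype.

    # Sketch of the dtype issue fixed above (assumed stand-in modules,
    # not the actual peft layer): the frozen base layer runs in bfloat16
    # while the LoRA matrices stay in float32, as in mixed-precision setups.
    import torch
    import torch.nn as nn

    base = nn.Linear(16, 16).bfloat16()    # frozen base weights in half precision
    lora_A = nn.Linear(16, 4, bias=False)  # LoRA factors kept in float32
    lora_B = nn.Linear(4, 16, bias=False)
    scaling = 0.5

    x = torch.randn(2, 16, dtype=torch.bfloat16)
    result = base(x)

    # Calling lora_A(x) directly would fail here: x is bfloat16, weights float32.
    if not torch.is_autocast_enabled():
        expected_dtype = result.dtype
        if x.dtype != torch.float32:
            x = x.float()  # upcast the input to match the float32 LoRA weights
        # compute the LoRA delta in float32, then cast back to the base dtype
        output = lora_B(lora_A(x)).to(expected_dtype) * scaling
    else:
        # under autocast, torch inserts the dtype casts automatically
        output = lora_B(lora_A(x)) * scaling

    result += output
    print(result.dtype)  # torch.bfloat16

Casting the delta back with .to(expected_dtype) keeps the residual addition in the base layer's dtype, so the branch is effectively a no-op when everything already runs in float32.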