diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py
index 82b8b6432b..9e830d3cfb 100644
--- a/src/peft/tuners/lora.py
+++ b/src/peft/tuners/lora.py
@@ -494,8 +494,18 @@ def reset_parameters(self):
 
         def forward(self, x: torch.Tensor):
             result = super().forward(x)
+
             if self.disable_adapters:
                 return result
             elif self.r > 0:
-                result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+                if not torch.is_autocast_enabled():
+                    expected_dtype = result.dtype
+
+                    if x.dtype != torch.float32:
+                        x = x.float()
+                    output = self.lora_B(self.lora_A(self.lora_dropout(x))).to(expected_dtype) * self.scaling
+                    result += output
+                else:
+                    output = self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
+                    result += output
             return result
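
For context, here is a minimal sketch of the failure mode the new `torch.is_autocast_enabled()` branch guards against. The names `base`, `lora_A`, `lora_B`, and `scaling` are illustrative stand-ins, not the peft API, and bfloat16 stands in for the half-precision base model so the snippet also runs on CPU:

```python
import torch
import torch.nn as nn

# Illustrative stand-ins (not the peft API): a frozen low-precision base
# layer plus trainable LoRA factors that stay in fp32.
base = nn.Linear(16, 16).to(torch.bfloat16)
lora_A = nn.Linear(16, 4, bias=False)   # fp32
lora_B = nn.Linear(4, 16, bias=False)   # fp32
scaling = 0.5

x = torch.randn(2, 16, dtype=torch.bfloat16)
result = base(x)  # low-precision activation, like the 8-bit layer's output

# Outside autocast, feeding the low-precision x straight into the fp32 LoRA
# path would raise a dtype mismatch ("mat1 and mat2 must have the same
# dtype"), so the patch upcasts the input and casts the LoRA output back to
# the dtype of the base result before accumulating.
expected_dtype = result.dtype
output = lora_B(lora_A(x.float())).to(expected_dtype) * scaling
result = result + output           # dtypes now match
print(result.dtype)                # torch.bfloat16
```

When autocast is enabled, PyTorch already selects the working dtype per op, so no manual cast is needed; that is why the `else` branch keeps the original computation unchanged.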