diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 05f7c04db..40766ad41 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -282,10 +282,13 @@ def from_prequantized( self.compress_statistics = self.quant_state.nested self.quant_type = self.quant_state.quant_type self.bnb_quantized = True + + self.quant_storage = data.dtype + return self def _quantize(self, device): - w = self.data.contiguous().cuda(device) + w = self.data.contiguous().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit( w, blocksize=self.blocksize, @@ -333,6 +336,7 @@ def to(self, *args, **kwargs): blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type, + quant_storage=self.quant_storage, ) return new_param @@ -450,7 +454,7 @@ def forward(self, x: torch.Tensor): # since we registered the module, we can recover the state here assert self.weight.shape[1] == 1 if not isinstance(self.weight, Params4bit): - self.weight = Params4bit(self.weight, quant_storage=self.quant_storage) + self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True) self.weight.quant_state = self.quant_state else: print(