Fix restoration of quant_storage for CPU offloading (#1279)
* Fix restoration of quant_storage for CPU offloading

* Clarify comment on default quant_storage in Params4bit.from_prequantized()

* Fix to make quant_storage dynamic based on the serialized dtype

* Delete obsolete comment

---------

Co-authored-by: Titus von Koeller <[email protected]>
matthewdouglas and Titus-von-Koeller authored Jul 23, 2024
1 parent e3ae243 commit 7fed393
Showing 1 changed file with 6 additions and 2 deletions.
bitsandbytes/nn/modules.py: 6 additions, 2 deletions
@@ -282,10 +282,13 @@ def from_prequantized(
         self.compress_statistics = self.quant_state.nested
         self.quant_type = self.quant_state.quant_type
         self.bnb_quantized = True
+
+        self.quant_storage = data.dtype
+
         return self

     def _quantize(self, device):
-        w = self.data.contiguous().cuda(device)
+        w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
             w,
             blocksize=self.blocksize,
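
Note: quant_storage is now inferred from the dtype of the serialized tensor itself rather than left at the uint8 default, and _quantize() accepts any device instead of hard-coding .cuda(). A minimal sketch of the behavior the first change enables (assumes bitsandbytes >= 0.43 and a CUDA device; the layer sizes are arbitrary):

    import torch
    import bitsandbytes as bnb

    # Pack the 4-bit weights into bfloat16 words instead of the uint8 default.
    linear = bnb.nn.Linear4bit(64, 64, quant_storage=torch.bfloat16)
    linear = linear.to("cuda")  # quantization happens on the move to the GPU

    # The serialized packed weight carries the storage dtype, which
    # from_prequantized() can now recover on load instead of assuming uint8.
    packed = linear.state_dict()["weight"]
    assert packed.dtype == torch.bfloat16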
@@ -333,6 +336,7 @@ def to(self, *args, **kwargs):
                 blocksize=self.blocksize,
                 compress_statistics=self.compress_statistics,
                 quant_type=self.quant_type,
+                quant_storage=self.quant_storage,
             )

             return new_param
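
Note: to() rebuilds the parameter as a fresh Params4bit whenever an already-quantized weight changes device, so any attribute not forwarded to the constructor silently resets to its default. A sketch of the attribute loss this hunk prevents, continuing the example above:

    # Moving an already-quantized weight rebuilds the Params4bit wrapper.
    offloaded = linear.weight.to("cpu")

    # Before this fix quant_storage was not forwarded, so the rebuilt
    # parameter reverted to the uint8 default; now the bfloat16 choice
    # survives the move.
    assert offloaded.quant_storage == torch.bfloat16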
@@ -450,7 +454,7 @@ def forward(self, x: torch.Tensor):
                 # since we registered the module, we can recover the state here
                 assert self.weight.shape[1] == 1
                 if not isinstance(self.weight, Params4bit):
-                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
+                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
                 self.weight.quant_state = self.quant_state
             else:
                 print(
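
Note: bnb_quantized=True matters here because the weight being re-wrapped (for example after accelerate CPU offloading drops the Params4bit subclass) already holds packed 4-bit bytes in an (n, 1) column tensor. Marking the new wrapper as quantized keeps a later .to("cuda") from taking the not-yet-quantized branch in to() and packing the bytes a second time. An illustrative re-wrap continuing the example above, not code from this diff:

    # Simulate losing the subclass: keep only the packed bytes, as an
    # offloading framework might.
    plain_weight = torch.nn.Parameter(linear.weight.data, requires_grad=False)
    saved_state = linear.weight.quant_state

    restored = bnb.nn.Params4bit(
        plain_weight,
        quant_storage=plain_weight.dtype,
        bnb_quantized=True,  # data is already packed; do not re-quantize
    )
    restored.quant_state = saved_state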
