Skip to content

Commit

Permalink
Merge pull request #213 from tonylins/dev/fix_no_absmax
Browse files Browse the repository at this point in the history
Fix a bug in (de)quantize_no_absmax with multiple GPUs
  • Loading branch information
TimDettmers authored Apr 11, 2023
2 parents 6b4c5af + b6383ba commit c787553
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor:
Quantized 8-bit tensor.
'''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
is_on_gpu([A, out])
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out


Expand All @@ -683,9 +685,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
torch.Tensor:
32-bit output tensor.
'''
prev_device = pre_call(A.device)
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out])
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
post_call(prev_device)
return out


Expand Down

0 comments on commit c787553

Please sign in to comment.