-
Notifications
You must be signed in to change notification settings - Fork 640
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable common device abstraction for 8bits/4bits #898
Changes from all commits
314d5e0
2e9550a
68fd024
ba92680
65b17a2
c5044e0
b23789a
b2a4d54
c44cf06
365491a
4050fe3
30175d1
e17549e
a53bc31
80c598c
59facc8
066d0dc
e34c30e
cebd83c
e0f2e18
d20c017
9f23308
b41c1c4
1ab611e
b933f9f
8b4baaa
0905ad7
145a835
d270832
68e7859
012b565
8fa27f6
2c04d48
c184655
03b53d7
ba7a162
d162998
2cd9718
f26a4e6
adfb5e2
6f08879
a9e4548
84f67d2
9ff6c63
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from typing import Dict | ||
|
||
from bitsandbytes.backends.base import Backend | ||
|
||
backends: Dict[str, Backend] = {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why should it be? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it intend to avoid the usage of an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does a user currently select a backend? Currently, only CUDA is supported, but should there be a function like I guess we would do this through the "device_setup" process. The question here is if we can automatically detect the device the user is running in all cases? I think the only exception is probably if a user has both an accelerated device and a CPU. I think having, for example, Apple silicon and a regular GPU will not really happen. Are there any other scenarios that we are missing here and we need to think about? I think for now, it looks fine, but I want to make sure we are not missing anything. In terms of usability, the best designs often come from early thought rather than later corrections. So it makes sense we think a bit about this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @TimDettmers To your point, I wonder if IPEX can be combined with CUDA/ROCm in such a way where as you mention, it's not clear what the user will want. E.g. a situation where both It's also my understanding that Intel GPU support may be upstreamed: pytorch/pytorch#114842 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jgong5 That's what's happening now in this PR 😁 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey all, so Tim thinks that the backend should only be initialized once and therefore implemented as Singleton: It According to him, there's no use-case to exchange the backend at runtime. The only potential use-case might be that of having both a CPU and GPU backend at the same time, but from what Tim says, this is sth that we currently don't need yet and shouldn't worry about. Just forwarding his statement for the sake of furthering the discussion. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From an engineering standpoint, I disagree with implementing it as a singleton (a class you can ever only initialize once). Doing that is more complex, a little non-Pythonic, and the current implementation has the same end result: there's a backend object that's only created once, and it's plugged into place in the backends dict. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I feel the same - no obvious benefit of constraining us with a single device. May I know what's the concern with dispatching device backend from the backend dict with the device on the tensor args? Dispatching according to the tensor's device type is something PyTorch ATen is also doing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it not make sense to try to stay with the device of the source/destination tensors rather than select and initialize the device once as a singleton? If you have multiple GPUs for example and want to share the compute with them, wouldn't you want to do .to(some_device) then call BnB? |
||
|
||
|
||
def register_backend(backend_name: str, backend_instance: Backend): | ||
backends[backend_name.lower()] = backend_instance | ||
|
||
|
||
def ensure_backend_is_available(device_type: str): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can have a better design here. Currently, this function is called before we call any backend function. This bloats the function calls. I think it would be better to set the backend once and do not keep a dictionary around. A singleton class would be appropriate for this and we initialize this class through the This also prevents a different problem: Currently, the dictionary is indexed via the device type (e.g. Another problem with this is that currently, paged optimizer buffers have two device types CPU and CUDA where on the PyTorch level only CPU is visible. As such, this would currently be executed on the wrong device. Although paged optimizers have CPU buffers they need to be executed on CUDA device code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also disagree somewhat with an
would do, or even If we set a backend once, we can't have a library that would support a system with, say, both a As for indexing the dictionary, I don't see that being a very pressing issue for just now. The Apple Silicon device type is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AMD (HIP) and any other GPU device will say return backends[A.device.type].dequantize_4bit(...) we add a function def get_tensor_backend(A: torch.Tensor):
if A.device.type=="cuda" and torch.version.hip:
return backends["hip"]
return backends[A.device.type] and then call return get_tensor_backend(A).dequantize_4bit(...) We could also include the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From what I understand we don't need to support multiple GPU backends at the same time, but at most CPU + GPU. I wonder what would be the best way to handle those two and initialize them only once at library initialization. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is that ? Isn't multiple GPUs a valid setup? EDIT: Possibly you mean only one backend of the same type? If so then it makes more sense. Because I think AMD+CUDA or MPS+CUDA (if Apple and Nvidia ever stop fighting and there are drivers again) are valid use cases There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AMD+CUDA seems like a stretch, considering PyTorch uses the same CUDA semantics for ROCm. As far as Apple goes, one would hope support for whatever the underlying hardware is can come through Metal if future GPU options become available. In fact, as it is MTLGPUFamily.metal3 already lists out the 2019 Mac Pro (x86-64) options from AMD. My perspective is that, at least for right now, Intel GPU + (NVIDIA or AMD) is the most viable mixed-vendor GPU setup. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably. I was more making the argument that in PyTorch you can spin up any number of heterogeneous devices and shuffle tensors between them easily. |
||
"""Check if a backend is available for the given device type.""" | ||
if device_type.lower() not in backends: | ||
raise NotImplementedError(f"Device backend for {device_type} is currently not supported.") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from abc import ABC, abstractmethod | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, the backend base contains too few functions or too many functions depending on the view. Currently, it does not provide abstractions for blockwise quantization, QLoRA-style double quantization, and 8-bit optimizers. On the other hand, This is definitely one thing that we need to discuss: what exact function do we abstract. We need to abstract everything that is needed by all devices and keep everything that is specific to CUDA in that particular backend. |
||
from typing import Optional, Tuple | ||
|
||
import torch | ||
|
||
from bitsandbytes.utils import QuantState | ||
|
||
|
||
class Backend(ABC): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cannot use an abstract base class here. This makes the interface too big to implement. We want people to be able to contribute sub-interfaces, for example, only implement the 4-bit functionality but not the 8-bit and 8-bit optimizer functionality. The intent of such a design it better captured by a base class that implements these functions with an I think the intend would be even clearer by having 4 backends: 4-bit, 8-bit, 8-bit optimizers, block-wise quantization. However, this will also introduce more bloat in terms of boilerplate and more classes. Not sure how to handle this and feedback would be appreciated. I think it might be better to have a single class and just highlight both as comments and in the documentation that not all functions need to be overridden for a solid contribution. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see now that the methods already throw a NotImplementedError. I think this is good already. So just removing the ABC would make it possible to implement sub-interfaces. I think to make the sub-interfaces clearer it would be great to have a NotImplementedError that shows the set of functions that need to be implemented. For example, mm_dequant(...)
...
raise NotImplementedError("mm_dequant not implemented! \
This function is part of the 8-bit interface and it needs to be implemented along with: \
mm_dequant, igemmlt, extract_outliers, double_quant") There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's okay to partially implement a backend, then sure, we can make it a concrete base class with NotImplementeds thrown around. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't partial implementations fall back to CPU? |
||
"""Base class for devices backends that will implement their own 8bits and 4bits functions.""" | ||
|
||
@abstractmethod | ||
def double_quant( | ||
self, | ||
A, | ||
col_stats=None, | ||
row_stats=None, | ||
out_col=None, | ||
out_row=None, | ||
threshold=0.0, | ||
): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def transform( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only needed for CUDA. This will probably not be needed for any other device. See the discussion above. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was somewhat touched upon #898 (comment) – this PR doesn't yet move all of the CUDA-specific things into place, but I think that's fine and we can clean it up in near-future work... |
||
self, | ||
A, | ||
to_order, | ||
from_order="row", | ||
out=None, | ||
transpose=False, | ||
state=None, | ||
ld=None, | ||
): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def mm_dequant( | ||
self, | ||
A, | ||
quant_state, | ||
row_stats, | ||
col_stats, | ||
out=None, | ||
new_row_stats=None, | ||
new_col_stats=None, | ||
bias=None, | ||
): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def extract_outliers(self, A, SA, idx): | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def quantize_4bit( | ||
self, | ||
A: torch.Tensor, | ||
absmax: Optional[torch.Tensor] = None, | ||
out: Optional[torch.Tensor] = None, | ||
blocksize=64, | ||
compress_statistics=False, | ||
quant_type="fp4", | ||
quant_storage=torch.uint8, | ||
) -> Tuple[torch.Tensor, QuantState]: | ||
""" | ||
Quantize tensor A in blocks of 4-bit values. | ||
|
||
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. | ||
|
||
Parameters | ||
---------- | ||
A : torch.Tensor | ||
The input tensor. | ||
absmax : torch.Tensor | ||
The absmax values. | ||
out : torch.Tensor | ||
The output tensor. | ||
blocksize : int | ||
The blocksize used in quantization. | ||
quant_type : str | ||
The 4-bit quantization data type {fp4, nf4} | ||
|
||
Returns | ||
------- | ||
torch.Tensor: | ||
Tensor with packed 4-bit values. | ||
tuple(torch.Tensor, torch.Size, torch.dtype, int): | ||
The quantization state to undo the quantization. | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def dequantize_4bit( | ||
self, | ||
A: torch.Tensor, | ||
quant_state: Optional[QuantState] = None, | ||
absmax: Optional[torch.Tensor] = None, | ||
out: Optional[torch.Tensor] = None, | ||
blocksize: int = 64, | ||
quant_type="fp4", | ||
) -> torch.Tensor: | ||
""" | ||
Dequantizes FP4 blockwise quantized values. | ||
|
||
Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. | ||
|
||
Parameters | ||
---------- | ||
A : torch.Tensor | ||
The input tensor (packed 4-bit values). | ||
quant_state : QuantState | ||
object with quantisation stats, incl. absmax values, original tensor shape and original dtype. | ||
absmax : torch.Tensor | ||
The absmax values. | ||
out : torch.Tensor | ||
Dequantized output tensor. | ||
blocksize : int | ||
The blocksize used in quantization. | ||
quant_type : str | ||
The 4-bit quantization data type {fp4, nf4} | ||
|
||
|
||
Returns | ||
------- | ||
torch.Tensor: | ||
Dequantized tensor. | ||
""" | ||
raise NotImplementedError |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this cause any problems? What if we have a backend that, upon initialization, makes assumptions about the hardware/system? I think this can work if the backend does not have any state. However, is it a realistic assumption if we think about other backends?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AIUI, this is only really here to keep things working at present and we could think about deferred initialization later.
Even in the preimage of this PR, bnb initializes a backend (the native library) at import time.