LLM.int8() Refactoring: Part 1 #1401
… in new igemmlt implementation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from functools import reduce  # Required in Python 3
-import operator
+from math import prod
```
We support Python 3.8+ only, so use the builtin.
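For context, `math.prod` (added in Python 3.8) is a drop-in builtin replacement for the older `reduce(operator.mul, ...)` idiom. A minimal sketch of the equivalence (the tuple value is illustrative):

```python
import operator
from functools import reduce
from math import prod

shape = (4, 3, 2)  # illustrative shape tuple

# Pre-3.8 idiom being replaced in the diff above:
legacy = reduce(operator.mul, shape, 1)

# Python 3.8+ builtin equivalent:
modern = prod(shape)

assert legacy == modern == 24
```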
bitsandbytes/autograd/_functions.py
```diff
@@ -245,11 +238,11 @@ class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
-    CxB = None
+    CxB = None  # TODO: Deprecate/remove
```
This won't be used anymore, but I'm not sure of the side effects of removing these properties either. There could be downstream integrations accessing them, or they may be used in serialization, etc. Any tips/thoughts here are welcome.
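One common pattern for this situation (a generic sketch, not the PR's actual code) is to keep the attribute readable and writable but emit a `DeprecationWarning` on access, so downstream integrations keep working while getting a nudge. The class name here is hypothetical:

```python
import warnings


class MatmulLtStateSketch:
    """Illustrative only; not the actual MatmulLtState from bitsandbytes."""

    CB = None
    _CxB = None  # retained backing storage for the deprecated attribute

    @property
    def CxB(self):
        warnings.warn(
            "CxB is deprecated and will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._CxB

    @CxB.setter
    def CxB(self, value):
        warnings.warn(
            "CxB is deprecated and will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
        self._CxB = value


state = MatmulLtStateSketch()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    state.CxB = 1   # triggers the setter warning
    _ = state.CxB   # triggers the getter warning

assert len(caught) == 2  # one warning for the set, one for the get
```

This keeps serialization code that reads the attribute functional while signaling the upcoming removal.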
```python
# Zero out the outliers in the transposed 8bit inputs.
if CAt is not None:
    CAt[:, state.idx] = 0
```
We skip this for inference now as it's also not needed.
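To illustrate what the skipped step does, here is a toy plain-Python sketch of the zeroing above; the names `CAt` and `state.idx` come from the PR (the transposed int8 input and the outlier column indices), but the values are made up:

```python
# Toy stand-in for the transposed int8 tensor CAt (3 rows, 4 columns).
CAt = [[0, 1, 2, 3],
       [4, 5, 6, 7],
       [8, 9, 10, 11]]
outlier_idx = [1, 3]  # hypothetical outlier column indices (state.idx)

# Equivalent of `CAt[:, state.idx] = 0`: zero the outlier columns in every row.
for row in CAt:
    for j in outlier_idx:
        row[j] = 0

assert all(row[1] == 0 and row[3] == 0 for row in CAt)
assert CAt[1][0] == 4 and CAt[1][2] == 6  # non-outlier columns untouched
```

During training the zeroed copy feeds the int8 matmul while the outlier columns go through a separate fp16 path; for inference-only use the transposed copy isn't needed, which is why the step can be skipped.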
```diff
-        if t is None:
-            continue  # NULL pointers are fine
-        is_paged = getattr(t, "is_paged", False)
-        on_gpu &= t.device.type == "cuda" or is_paged
-        if not is_paged:
+        # NULL pointers and paged tensors are OK.
+        if t is not None and not getattr(t, "is_paged", False):
+            on_gpu &= t.is_cuda
             gpu_ids.add(t.device.index)
```
This isn't specific to int8, but while I was profiling I noticed an opportunity to slightly improve some of the overhead here.
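A self-contained sketch of the consolidated check, using `types.SimpleNamespace` stand-ins instead of real torch tensors (the function body mirrors the new branch in the diff above; the surrounding function and helper are illustrative, not the actual bitsandbytes code):

```python
from types import SimpleNamespace


def make_tensor(is_cuda=True, index=0, is_paged=False):
    """Stand-in for a torch tensor exposing only the attributes the check reads."""
    return SimpleNamespace(is_cuda=is_cuda, is_paged=is_paged,
                           device=SimpleNamespace(index=index))


def is_on_gpu(tensors):
    # One branch now skips both None and paged tensors, and the cheaper
    # `t.is_cuda` replaces the `t.device.type == "cuda"` string comparison.
    on_gpu = True
    gpu_ids = set()
    for t in tensors:
        # NULL pointers and paged tensors are OK.
        if t is not None and not getattr(t, "is_paged", False):
            on_gpu &= t.is_cuda
            gpu_ids.add(t.device.index)
    # All non-skipped tensors must be CUDA tensors on the same device.
    return bool(on_gpu) and len(gpu_ids) <= 1


assert is_on_gpu([make_tensor(), None, make_tensor(is_paged=True, is_cuda=False)])
assert not is_on_gpu([make_tensor(index=0), make_tensor(index=1)])
```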
```diff
-  for(int i = threadIdx.x; i < 16; i++)
-    quant_map[i] = T(datatype[i]);
+  if (threadIdx.x < 16)
+    quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
+  //for(int i = threadIdx.x; i < 16; i++)
+  //  quant_map[i] = T(__ldg(&datatype[i]));
```
Not int8 but another small 4bit change that wanted to sneak its way in. @TimDettmers I'm just looking for a sanity check here that this makes sense.
A couple more comments, mostly regarding documenting the newly public-and-recommended functions.
There's also a bunch of commented-out code that probably should get removed, or it'll haunt the source forever otherwise, I think...
```python
if SA is not None and SA[1] != "row":
    raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
if SB is not None and SB[1] != "row":
    raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`")
```
According to the type annotations, SA and SB can't be None to begin with. However, they're not even used in the inner `int8_linear_matmul()` call...?
Those arguments were used in the old version, but serve no purpose now. So the main point here is to advise users that happen to try to invoke the old function with col32/ampere/turing layouts that it's not supported anymore. If you invoke with a row-major layout, it's still supported.
SA and SB were required; I'm just being defensive here with checking that they actually are not None.
Maybe it'd be better to be even more defensive: just fail altogether if these arguments are passed, as they won't be used/assigned to/...?
```diff
-if SA is not None and SA[1] != "row":
-    raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
-if SB is not None and SB[1] != "row":
-    raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`")
+if SA is not None:
+    raise NotImplementedError("SA will not be used; please pass None")
+if SB is not None:
+    raise NotImplementedError("SB will not be used; please pass None")
```
or similar..?
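To make the suggested "fail altogether" behavior concrete, here is a runnable sketch; the wrapper name and signature are hypothetical, not the actual bitsandbytes API:

```python
def int8_linear_matmul_compat(A, B, SA=None, SB=None):
    # Hypothetical wrapper illustrating the suggestion above: reject the
    # obsolete SA/SB arguments outright instead of inspecting their layout.
    if SA is not None:
        raise NotImplementedError("SA will not be used; please pass None")
    if SB is not None:
        raise NotImplementedError("SB will not be used; please pass None")
    return None  # stand-in for the real int8 kernel call


# A caller still passing the old layout tuple gets an immediate, clear error.
try:
    int8_linear_matmul_compat("A", "B", SA=((1, 1), "col32"))
except NotImplementedError as e:
    print(e)  # SA will not be used; please pass None
```

The tradeoff versus the row-major check is stricter: even a technically harmless row-major tuple is rejected, which surfaces stale call sites sooner.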
Co-authored-by: Aarni Koskela <[email protected]>
Hey @matthewdouglas, Thanks again for the insightful two-hour pairing session – it was great to walk through your code together. I’m impressed by your thoughtful review and the careful attention to detail in this work. There was a lot of complexity to handle pragmatically and I love the incremental refactoring approach that you took. The performance improvements are also really impressive. Great work! Here's my feedback that I collected during our talk:
Big thanks also to @akx for making time to review this work! We really appreciate your proactive contributions and helpful insights 🤗 Thanks ❤️
Co-authored-by: Aarni Koskela <[email protected]>
There have now been some documentation updates, both for the inline docstrings and the markdown-format public docs. Additionally, tests related to 8bit now use static shapes. Certain tests related to benchmarking have been extracted away, and others have had a new […]

A more detailed look at benchmarking data will be provided with release materials. For now, an overview of inference benchmark results:
This PR is the initial phase of a set of changes aimed at improving the LLM.int8() implementation.
Still in draft at the moment, but since there's a lot here I'm ready to have eyes on it.
@TimDettmers @Titus-von-Koeller
Primary Purpose

Enhancements

- `NO_CUBLASLT` build while retaining compatibility for targets below sm_75 (verification in progress)
- `F.int8_vectorwise_quant` (…)

Deprecations

The following functions from `bitsandbytes` are deprecated:

The following functions from `bitsandbytes.functional` are deprecated:

Further testing and benchmarking will be coming. At the moment unit tests are passing.
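For readers wondering how deprecated module-level functions are typically surfaced without breaking imports, one common approach is a PEP 562 module `__getattr__` that warns and forwards to the replacement. This is a generic sketch (the module and function names are hypothetical, not what this PR does):

```python
import types
import warnings


def _new_impl(x):
    """Stand-in for the replacement function."""
    return x * 2


# Build a throwaway module object so the sketch is self-contained; in practice
# this logic lives at the top level of the package's own module.
mod = types.ModuleType("bnb_sketch")  # hypothetical module name
mod._DEPRECATED = {"old_func": _new_impl}  # hypothetical deprecated name


def _module_getattr(name, _mod=mod):
    # PEP 562: called when normal attribute lookup on the module fails.
    if name in _mod._DEPRECATED:
        warnings.warn(f"bnb_sketch.{name} is deprecated", DeprecationWarning,
                      stacklevel=2)
        return _mod._DEPRECATED[name]
    raise AttributeError(name)


mod.__getattr__ = _module_getattr

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = mod.old_func(21)  # resolves via __getattr__, with a warning

print(result)  # 42
```

Callers keep working unchanged while each use of a deprecated name produces a `DeprecationWarning` pointing at the call site.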
Next steps

- Further improvement of sparse decomposition performance (deferred to future PRs)