LLM.int8() Refactoring: Part 1 #1401

Open · matthewdouglas wants to merge 59 commits into main
Conversation

matthewdouglas (Member) commented on Oct 24, 2024:

This PR is the initial phase of a set of changes aimed at improving the LLM.int8() implementation.

Still in draft at the moment, but since there's a lot here I'm ready to have eyes on it.
@TimDettmers @Titus-von-Koeller

Primary Purpose

Enhancements

  • Removes the usage of Turing- and Ampere-specific memory layouts while retaining compatibility across sm_75 through sm_89.
    • Simplifies the code and reduces the surface area that needs to be maintained.
    • Reduces overhead by removing layout transformation operations.
  • Removes the separate NO_CUBLASLT build while retaining compatibility for targets below sm_75. (verification in progress)
    • This simplifies building and packaging, and trims the size of binary wheels by roughly half.
  • Adds support for CUDA Graph tracing to bring parity with 4bit.
  • Improves kernels for inference:
    • Fused kernel for activation scale calibration and quantization, exposed as the op F.int8_vectorwise_quant (see the sketch after this list).
    • Other kernels simplified to operate on row-major data.
  • Makes many unit tests more reliable with increased determinism.
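
As an illustrative aside (not part of the PR text), here is a minimal sketch of calling the fused op from Python; the exact signature and return values of F.int8_vectorwise_quant (a threshold argument; int8 data, per-row absmax, outlier columns) are assumptions here and may differ from the final API:

import torch
import bitsandbytes.functional as F

# Hypothetical fp16 activations, row-major: 8 tokens x 4096 features.
A = torch.randn(8, 4096, dtype=torch.float16, device="cuda")

# Assumed to return the quantized int8 rows, the per-row absmax used as the
# scale, and (when threshold > 0) the indices of any outlier columns.
A_int8, row_stats, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)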

Deprecations

The following functions from bitsandbytes are deprecated:

mm_cublas
bmm_cublas
matmul_cublas

The following functions from bitsandbytes.functional are deprecated (a migration sketch for the int8 matmul replacements follows the list):

_mul
arange
dequant_min_max
dequantize_no_absmax
extract_outliers
get_special_format_str
get_tensor_stream (moved to internal API)
get_transform_buffer
get_transform_func
mm_dequant (replacement: int8_mm_dequant)
igemmlt (replacement: int8_linear_matmul)
nvidia_transform
post_call
pre_call
transform
quantize_no_absmax
vectorwise_dequant
vectorwise_quant (~replacement: int8_vectorwise_quant)
vectorwise_mm_dequant (~replacement: int8_mm_dequant)
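
As an illustrative aside (not part of the PR text), a rough migration sketch for the int8 matmul path, replacing the deprecated igemmlt/mm_dequant pair with int8_linear_matmul/int8_mm_dequant; the signatures shown are assumptions based on the replacement names above and may not match the final API exactly:

import torch
import bitsandbytes.functional as F

# Hypothetical shapes: activations A [tokens, in_features], weights W [out_features, in_features].
A = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
W = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")

# Row-wise int8 quantization of both operands (assumed to return int8 data, per-row absmax, outliers).
A_int8, A_stats, _ = F.int8_vectorwise_quant(A)
W_int8, W_stats, _ = F.int8_vectorwise_quant(W)

# Replaces the deprecated igemmlt + mm_dequant calls:
out32 = F.int8_linear_matmul(A_int8, W_int8)           # int8 x int8 -> int32 accumulator (A @ W.T)
output = F.int8_mm_dequant(out32, A_stats, W_stats)    # rescale back to 16-bit using the row statistics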

Further testing and benchmarking will follow. At the moment, unit tests are passing.

Next steps

  • Clean up and reorganize unit tests
  • Documentation for public APIs
  • Ensure fallback path for shapes that don't work well with cuBLASLt (i.e. m/k not multiples of 4).
  • Add an int8 dequantize op
  • Further improvement of sparse decomposition performance (Deferred to future PRs)
  • Conduct profiling, benchmarks, and evaluations
    • Build benchmark/evaluation scripts
    • Prepare analysis of results


@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from functools import reduce  # Required in Python 3
-import operator
+from math import prod
matthewdouglas (Member Author):
We support Python 3.8+ only, so use the builtin.
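
For reference (an aside, not from the PR), the builtin behaves the same as the reduce idiom being removed:

import operator
from functools import reduce
from math import prod

shape = (4, 128, 64)
# math.prod (Python 3.8+) replaces the reduce/operator.mul pattern.
assert prod(shape) == reduce(operator.mul, shape, 1) == 32768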

@@ -245,11 +238,11 @@ class MatmulLtState:
     _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
-    CxB = None
+    CxB = None  # TODO: Deprecate/remove
matthewdouglas (Member Author):
This won't be used anymore, but I'm not sure of the side effects of removing these properties either. Downstream integrations could be accessing them, and they may be used in serialization, etc. Any tips/thoughts here are welcome.

Comment on lines +345 to +347
# Zero out the outliers in the transposed 8bit inputs.
if CAt is not None:
    CAt[:, state.idx] = 0
matthewdouglas (Member Author):
We skip this for inference now as it's also not needed.

Comment on lines -439 to +431
-        if t is None:
-            continue  # NULL pointers are fine
-        is_paged = getattr(t, "is_paged", False)
-        on_gpu &= t.device.type == "cuda" or is_paged
-        if not is_paged:
+        # NULL pointers and paged tensors are OK.
+        if t is not None and not getattr(t, "is_paged", False):
+            on_gpu &= t.is_cuda
             gpu_ids.add(t.device.index)

matthewdouglas (Member Author):
This isn't specific to int8, but while I was profiling I noticed an opportunity to slightly reduce some of the overhead here.

Comment on lines -3528 to +3574
-    for(int i = threadIdx.x; i < 16; i++)
-        quant_map[i] = T(datatype[i]);
+    if (threadIdx.x < 16)
+        quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
+    //for(int i = threadIdx.x; i < 16; i++)
+    //    quant_map[i] = T(__ldg(&datatype[i]));
matthewdouglas (Member Author):
Not int8 but another small 4bit change that wanted to sneak its way in. @TimDettmers I'm just looking for a sanity check here that this makes sense.

@akx (Contributor) left a comment:

A couple more comments, mostly regarding documenting the newly public-and-recommended functions.

There's also a bunch of commented-out code that should probably be removed; otherwise I think it'll haunt the source forever...

Comment on lines +2318 to +2321
if SA is not None and SA[1] != "row":
    raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
if SB is not None and SB[1] != "row":
    raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`")
akx (Contributor):

According to the type annotations, SA and SB can't be None to begin with.

However, they're not even used in the inner int8_linear_matmul() call...?

matthewdouglas (Member Author):

Those arguments were used in the old version, but serve no purpose now. So the main point here is to advise users who happen to invoke the old function with col32/ampere/turing layouts that those are no longer supported. Invoking with a row-major layout is still supported.

SA and SB were required arguments; I'm just being defensive here by checking that they actually are not None.

akx (Contributor):

Maybe it'd be better to be even more defensive: just fail altogether if these arguments are passed, as they won't be used/assigned to/...?

Suggested change
-if SA is not None and SA[1] != "row":
-    raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
-if SB is not None and SB[1] != "row":
-    raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`")
+if SA is not None:
+    raise NotImplementedError("SA will not be used; please pass None")
+if SB is not None:
+    raise NotImplementedError("SB will not be used; please pass None")

or similar..?

@Titus-von-Koeller (Collaborator) commented:

Hey @matthewdouglas,

Thanks again for the insightful two-hour pairing session – it was great to walk through your code together. I’m impressed by your thoughtful review and the careful attention to detail in this work. There was a lot of complexity to handle pragmatically and I love the incremental refactoring approach that you took. The performance improvements are also really impressive. Great work!


Here's my feedback that I collected during our talk:

  1. Organize Test Scripts
    Consider moving scripty test parts under bitsandbytes/scripts/8-bit. Adding a reference to that in the main implementation would help guide developers to these “eval” scripts for future refactoring.

  2. Clarify absmax Logic
    In get_row_absmax, please add an explanation of why taking absmax only over rows is sufficient (see the sketch after this list for the basic intuition).

  3. Commentary in MatMul8bitLt
    You mentioned needing a comment in MatMul8bitLt – could you clarify the specific addition required here?

  4. Documenting Public Functions
    Ensure all public functions have clear, detailed docstrings and verify their proper rendering in the documentation.

  5. Deterministic Test Inputs
    It makes a lot of sense to use hard-coded test inputs to improve consistency over the prior approach of randomization. Please make sure this holds for all 8-bit related tests before concluding this PR. A follow-up PR applying the same approach to other tests would help address ongoing flakiness and would be highly appreciated.

  6. Profiling Code Placement
    Please commit your profiling code to the main repo in a reasonable location and/or move more experimental/supplementary code to the workbench repo for future team reference.

  7. Benchmark Transparency for Users
    Adding benchmark results to the documentation would greatly benefit users, especially in a “deep-dive” section. Please clearly highlight performance comparisons with 16-bit, underscoring benefits with large context and batch sizes, where overhead remains constant. H100 benchmarks could add value but might be low priority. Focus on takeaways from performance, giving users accessible insights from your mental model, so they “know what we know”.

  8. Publicity-Worthy Performance Metrics
    Do we have any benchmark metrics from this refactor that might serve as release highlights?
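
(Illustrative aside, not part of the review.) A minimal sketch of the intuition behind item 2, under the assumption that LLM.int8() quantizes activations per row and weights per output row: every element of the int32 product can then be rescaled by one activation-row scale times one weight-row scale, so a per-row absmax of the activations is all that is needed. The helper name below is hypothetical:

import torch

def rowwise_absmax_quant(x: torch.Tensor):
    # Hypothetical helper: symmetric int8 quantization with one scale per row.
    absmax = x.abs().amax(dim=1, keepdim=True)
    x_int8 = torch.clamp((x * (127.0 / absmax)).round(), -127, 127).to(torch.int8)
    return x_int8, absmax

A = torch.randn(4, 64)   # activations: [rows, features]
W = torch.randn(8, 64)   # weights: one row per output feature

A_q, a_scale = rowwise_absmax_quant(A)
W_q, w_scale = rowwise_absmax_quant(W)

# int32 matmul, then rescale element (i, j) by a_scale[i] * w_scale[j] / 127^2.
out32 = A_q.to(torch.int32) @ W_q.to(torch.int32).T
approx = out32.float() * (a_scale @ w_scale.T) / (127.0 * 127.0)

print((approx - A @ W.T).abs().max())  # small relative to the magnitude of the fp32 result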


Big thanks also to @akx for making time to review this work! We really appreciate your proactive contributions and helpful insights 🤗 Thanks ❤️

matthewdouglas marked this pull request as ready for review on November 25, 2024 at 17:37

matthewdouglas (Member Author):

There have now been some documentation updates both for the inline docstrings and the markdown-format public docs.

Additionally, tests related to 8bit now use static shapes. Certain tests related to benchmarking have been extracted away, and others have had a new deprecated marker applied where appropriate.

A more detailed look at benchmarking data will be provided with release materials. For now, an overview of inference benchmark results:

  • INT8:
    • On T4 and 4090, the per-token throughput is improved by 60-85% and per-token latency is decreased by 40-45%.
    • H100 is now supported. With Llama 3.1 70B and batch size >= 8, INT8 is consistently faster than NF4.
  • NF4:
    • On T4 and 4090, with batch size of 1, per-token throughput is improved by 10-25% and per-token latency is decreased by 10-20%.
    • On H100, across all batch sizes, per-token throughput is improved by up to 28% and per-token latency is decreased by up to 22%.

Labels: enhancement (New feature or request), Medium risk (Risk of bugs in transformers and other libraries)
Projects: None yet
Development: Successfully merging this pull request may close these issues:
  • [Bug] Exception: cublasLt ran into an error! during fine-tuning LLM in 8bit mode
3 participants