Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Initial Machete W4A8 support + Refactors #9855

Merged
merged 9 commits into from
Nov 18, 2024

Conversation

LucasWilkinson
Copy link
Contributor

@LucasWilkinson LucasWilkinson commented Oct 30, 2024

  • add machete kernel support for QQQ style w4a8 quantization (including fp8 activations)
  • refactor machete dispatching logic
  • refactor machete file generation
  • reduce the number of prepack kernels generated (e.g. don't need separate kernels for uint4b8 and uint4 since we are just shuffling data around)

TODO (Future PR):

  • end2end integration
  • perf improvements (mostly hoping to land now for the refactor)

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@LucasWilkinson LucasWilkinson changed the title [WIP, Kernel] (3/N) Machete W4A8 (signed) [WIP, Kernel] (3/N) Machete W4A8 Oct 30, 2024
@LucasWilkinson LucasWilkinson marked this pull request as ready for review November 1, 2024 03:17
@LucasWilkinson LucasWilkinson changed the title [WIP, Kernel] (3/N) Machete W4A8 [Kernel] (3/N) Machete W4A8 Nov 1, 2024
@LucasWilkinson LucasWilkinson changed the title [Kernel] (3/N) Machete W4A8 [Kernel] (3/N) Initial Machete W4A8 support + Refactors Nov 1, 2024
@LucasWilkinson LucasWilkinson changed the title [Kernel] (3/N) Initial Machete W4A8 support + Refactors [Kernel] Initial Machete W4A8 support + Refactors Nov 1, 2024
Copy link
Contributor

@ProExpertProg ProExpertProg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 minor comments, looks good otherwise

Copy link

mergify bot commented Nov 6, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @LucasWilkinson please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 6, 2024
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/machete-w4a8-signed branch 2 times, most recently from 3555a56 to 5c09f95 Compare November 6, 2024 15:49
@mergify mergify bot removed the needs-rebase label Nov 6, 2024
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/machete-w4a8-signed branch 2 times, most recently from 565770c to 630c540 Compare November 6, 2024 20:05
([16384, 16384], 0),
([16384, 106496], 1),
([53248, 16384], 0),
],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit : for big models, I have found it useful to have their realistic TPn counter-parts also (e.g. for the 70B case, add a 70B-TP4 case). That way we can just list that version in the 1GPU model benchmarking.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean so you can list it as a string as opposed to using the --tp-sizes args?

vllm/_custom_ops.py Outdated Show resolved Hide resolved
benchmarks/kernels/benchmark_machete.py Show resolved Hide resolved
r;
uint32_t src = src_[0];
// Determines if to get from the signed or unsigned candidates
uint32_t sign = (src & 0x88888888) >> 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the right shift for?

Copy link
Contributor Author

@LucasWilkinson LucasWilkinson Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated the comment to be more verbose reads as:

      // Determines if to get from the signed or unsigned candidates
      //  move into bit position 0x4 of each nibble so when or'd with
      //  final_prmt_base it selects the correct candidate, when elements
      //  in final_prmt_base are >= 0x4, the negative candidate is selected
      //  (i.e. from NEG_INT8_REG{1}{2}), when elements are < 0x4, the positive
      //  candidate is selected (i.e. from POS_INT8_REG{1}{2})
      uint32_t sign = (src & 0x88888888) >> 1;

      // `sign` is OR'd with 0x31203120 to find the correct value in the LUT
      // (selects correct positive or negative candidate)
      const uint32_t final_prmt_base = 0x32103210;

      // Ignore sign bit when indexing into LUT, for each 4bit value
      //  we index into both the positive and negative candidates then use
      //  sign | final_prmt_base to select the correct candidate
      uint32_t lut_idx = (src & 0x77777777);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm refactored to use lut_4bit_to_8bit_convert utility function, now the comment looks like

  // Determines if the value is in the top half of the LUT if set or
  //  (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move
  //  into bit position 0x4 of each nibble so when or'd with final_prmt_base it
  //  selects the correct candidate. When elements in final_prmt_base
  //  are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
  //  are  < 0x4, the low candidate is selected (i.e. LUT[0:7])
  uint32_t high_bit = (src & 0x88888888) >> 1;

  // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT
  // (selects correct high or low candidate)
  const uint32_t final_prmt_base = 0x32103210;

  // Ignore the high bit when indexing into LUT, for each 4bit value
  //  we index into both the high and low candidates then use
  //  high_bit | final_prmt_base to select the correct candidate
  uint32_t lut_idx = (src & 0x77777777);

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool!

csrc/cutlass_extensions/vllm_numeric_conversion.cuh Outdated Show resolved Hide resolved
Comment on lines 246 to 249
static constexpr uint32_t POS_E4M3s_REG1 = 0x44403800; // [0, 1, 2, 3]
static constexpr uint32_t POS_E4M3s_REG2 = 0x4E4C4A48; // [4, 5, 6, 7]
static constexpr uint32_t NEG_E4M3s_REG1 = 0xCACCCED0; // [-8,-7,-6,-5]
static constexpr uint32_t NEG_E4M3s_REG2 = 0xB8C0C4C8; // [-4,-3,-2,-1]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like the int4 -> fp8 and int4 -> int8 converters the same except for these constants? If so, might be nice to factor these out. Not a big deal though, because it'd be kind of annoying to do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, bit janky but refactored to lut_4bit_to_8bit_convert

Comment on lines +174 to +176
using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{})));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain this change?

Copy link
Contributor Author

@LucasWilkinson LucasWilkinson Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its just moved up (to be closer to SmemLayoutA) from below, where the following is deleted:

   using SmemLayoutACopy = decltype(tile_to_shape(
       SmemLayoutAtomARowMajor{},
       make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
                  Int<DispatchPolicy::Stages>{}),
       conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
                     Step<_2, _1, _3>, Step<_1, _2, _3>>{}));

the

 conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
                     Step<_2, _1, _3>, Step<_1, _2, _3>>{}))

is removed since it was just cruft from the original PR thats not actually exercised (that was my bad)

@varun-sundar-rabindranath
Copy link
Contributor

Reviewed the cutlass refactor part - LGTM!

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work, LGTM!

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 13, 2024
csrc/quantization/machete/generate.py Outdated Show resolved Hide resolved
tests/kernels/test_machete_gemm.py Outdated Show resolved Hide resolved
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Copy link
Collaborator

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM thanks for getting it green!

@mgoin mgoin merged commit 96d999f into vllm-project:main Nov 18, 2024
71 checks passed
mikejuliet13 pushed a commit to mikejuliet13/vllm that referenced this pull request Nov 19, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants