[Kernel] Initial Machete W4A8 support + Refactors #9855
2 minor comments, looks good otherwise
([16384, 16384], 0),
([16384, 106496], 1),
([53248, 16384], 0),
],
nit: for big models, I have found it useful to also include their realistic TPn counterparts (e.g. for the 70B case, add a 70B-TP4 case). That way we can just list that version in the 1-GPU model benchmarking.
You mean so you can list it as a string as opposed to using the --tp-sizes args?
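To make the TPn suggestion concrete, here is a minimal sketch of how a pre-sharded shape entry could be derived from a full-size one. It assumes the second element of each tuple in the diff is the index of the dimension that is divided across tensor-parallel ranks; `WeightShape` and `shard_shape` are hypothetical names used only for illustration and are not part of the PR.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical mirror of one benchmark entry: ([dim0, dim1], split_dim).
// ASSUMPTION: `tp_split_dim` names the dimension divided across TP ranks.
struct WeightShape {
  std::array<int64_t, 2> dims;
  int tp_split_dim;
};

// Returns the per-rank shape for a given tensor-parallel size.
inline std::array<int64_t, 2> shard_shape(const WeightShape& w,
                                          int64_t tp_size) {
  if (tp_size <= 0 || w.tp_split_dim < 0 || w.tp_split_dim > 1) {
    throw std::invalid_argument("bad tp_size or split dimension");
  }
  std::array<int64_t, 2> out = w.dims;
  if (out[w.tp_split_dim] % tp_size != 0) {
    throw std::invalid_argument("dimension not divisible by tp_size");
  }
  out[w.tp_split_dim] /= tp_size;  // only the split dimension shrinks
  return out;
}
```

Under that assumption, a TP4 variant of `([16384, 106496], 1)` would be listed as `[16384, 26624]`, and `([53248, 16384], 0)` as `[13312, 16384]`.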
r;
uint32_t src = src_[0];
// Determines if to get from the signed or unsigned candidates
uint32_t sign = (src & 0x88888888) >> 1;
What is the right shift for?
Updated the comment to be more verbose; it now reads:
// Determines if to get from the signed or unsigned candidates
// move into bit position 0x4 of each nibble so when or'd with
// final_prmt_base it selects the correct candidate, when elements
// in final_prmt_base are >= 0x4, the negative candidate is selected
// (i.e. from NEG_INT8_REG{1}{2}), when elements are < 0x4, the positive
// candidate is selected (i.e. from POS_INT8_REG{1}{2})
uint32_t sign = (src & 0x88888888) >> 1;
// `sign` is OR'd with 0x32103210 to find the correct value in the LUT
// (selects correct positive or negative candidate)
const uint32_t final_prmt_base = 0x32103210;
// Ignore sign bit when indexing into LUT, for each 4bit value
// we index into both the positive and negative candidates then use
// sign | final_prmt_base to select the correct candidate
uint32_t lut_idx = (src & 0x77777777);
nvm, refactored to use the lut_4bit_to_8bit_convert utility function; the comment now looks like:
// Determines if the value is in the top half of the LUT (i.e. LUT[8:15]) if
// set, or in the bottom half (i.e. LUT[0:7]) if not set. Then move
// into bit position 0x4 of each nibble so when or'd with final_prmt_base it
// selects the correct candidate. When elements in final_prmt_base
// are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements
// are < 0x4, the low candidate is selected (i.e. LUT[0:7])
uint32_t high_bit = (src & 0x88888888) >> 1;
// `high_bit` is OR'd with 0x32103210 to find the correct value in the LUT
// (selects correct high or low candidate)
const uint32_t final_prmt_base = 0x32103210;
// Ignore the high bit when indexing into LUT, for each 4bit value
// we index into both the high and low candidates then use
// high_bit | final_prmt_base to select the correct candidate
uint32_t lut_idx = (src & 0x77777777);
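The selection scheme described in the comment can be modeled on the host to check the bit manipulation. The sketch below emulates the default mode of the PTX `prmt.b32` instruction and applies it three times: once to fetch the low candidates, once to fetch the high candidates, and once to choose between them. The names `prmt`, `pack4`, `lut_4bit_to_8bit_emulated`, and `kInt4ToInt8Lut` are hypothetical, the int4 to int8 table is an assumed two's-complement LUT, and processing the eight nibbles as two 16-bit halves is an assumption; this is not the PR's implementation.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Host model of `prmt.b32 d, a, b, c` (default mode): each of the four low
// nibbles of `c` picks one byte from the pool {b, a}; bytes 0-3 come from
// `a`, bytes 4-7 from `b`. Selector bit 3 replicates the byte's sign bit.
inline uint32_t prmt(uint32_t a, uint32_t b, uint32_t c) {
  const uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t d = 0;
  for (int i = 0; i < 4; ++i) {
    const uint32_t sel = (c >> (4 * i)) & 0xF;
    uint32_t byte = static_cast<uint32_t>(pool >> (8 * (sel & 0x7))) & 0xFF;
    if (sel & 0x8) byte = (byte & 0x80) ? 0xFF : 0x00;
    d |= byte << (8 * i);
  }
  return d;
}

// Packs four consecutive LUT entries into one register, lowest byte first.
inline uint32_t pack4(const std::array<uint8_t, 16>& lut, int first) {
  return uint32_t(lut[first]) | (uint32_t(lut[first + 1]) << 8) |
         (uint32_t(lut[first + 2]) << 16) | (uint32_t(lut[first + 3]) << 24);
}

// Converts eight packed 4-bit values into eight bytes (two registers).
inline std::array<uint32_t, 2> lut_4bit_to_8bit_emulated(
    uint32_t src, const std::array<uint8_t, 16>& lut) {
  const uint32_t low0 = pack4(lut, 0), low1 = pack4(lut, 4);
  const uint32_t high0 = pack4(lut, 8), high1 = pack4(lut, 12);
  uint32_t high_bit = (src & 0x88888888u) >> 1;  // bit 3 -> bit 2 per nibble
  const uint32_t final_prmt_base = 0x32103210u;
  uint32_t lut_idx = src & 0x77777777u;          // low 3 bits per nibble
  std::array<uint32_t, 2> out{};
  for (int i = 0; i < 2; ++i, lut_idx >>= 16, high_bit >>= 16) {
    const uint32_t low = prmt(low0, low1, lut_idx);     // LUT[0:7] candidates
    const uint32_t high = prmt(high0, high1, lut_idx);  // LUT[8:15] candidates
    out[i] = prmt(low, high, final_prmt_base | high_bit);
  }
  return out;
}

// ASSUMED table: signed int4 nibble -> two's-complement int8 byte.
const std::array<uint8_t, 16> kInt4ToInt8Lut = {
    0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
    0xF8, 0xF9, 0xFA, 0xFB, 0xFC, 0xFD, 0xFE, 0xFF};
```

With this table, the nibbles `7, 1, 8, F` in the low half of `0x3210F817` map to the bytes `0x07, 0x01, 0xF8, 0xFF`: the two nibbles with the high bit set turn selector nibbles `2` and `3` of `final_prmt_base` into `6` and `7`, which pick from the high candidates.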
cool!
static constexpr uint32_t POS_E4M3s_REG1 = 0x44403800; // [0, 1, 2, 3]
static constexpr uint32_t POS_E4M3s_REG2 = 0x4E4C4A48; // [4, 5, 6, 7]
static constexpr uint32_t NEG_E4M3s_REG1 = 0xCACCCED0; // [-8,-7,-6,-5]
static constexpr uint32_t NEG_E4M3s_REG2 = 0xB8C0C4C8; // [-4,-3,-2,-1]
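The bracketed comments on these constants can be checked by decoding each byte as an FP8 E4M3 value (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits), reading bytes from least to most significant. The following is a small host-side sketch; `decode_e4m3` and `byte_of` are hypothetical helper names, and the constants are copied from the diff above.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

static constexpr uint32_t POS_E4M3s_REG1 = 0x44403800;  // [0, 1, 2, 3]
static constexpr uint32_t POS_E4M3s_REG2 = 0x4E4C4A48;  // [4, 5, 6, 7]
static constexpr uint32_t NEG_E4M3s_REG1 = 0xCACCCED0;  // [-8,-7,-6,-5]
static constexpr uint32_t NEG_E4M3s_REG2 = 0xB8C0C4C8;  // [-4,-3,-2,-1]

// Extracts byte `i` (0 = least significant) from a packed register.
inline uint8_t byte_of(uint32_t reg, int i) {
  return static_cast<uint8_t>((reg >> (8 * i)) & 0xFF);
}

// Decodes one E4M3 byte to float: value = (-1)^s * 1.mmm * 2^(e - 7),
// with subnormals when e == 0 and NaN when e == 15 and mmm == 111.
inline float decode_e4m3(uint8_t bits) {
  const int sign = bits >> 7;
  const int exp = (bits >> 3) & 0xF;
  const int mant = bits & 0x7;
  if (exp == 0xF && mant == 0x7) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  const float mag =
      (exp == 0)
          ? std::ldexp(static_cast<float>(mant), -9)  // (mant / 8) * 2^-6
          : std::ldexp(static_cast<float>(8 + mant), exp - 10);
  return sign ? -mag : mag;
}
```

For example, byte 0 of `NEG_E4M3s_REG1` is `0xD0` (sign 1, exponent 10, mantissa 0), which decodes to -8, matching the first entry of its comment.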
It looks like the int4 -> fp8 and int4 -> int8 converters are the same except for these constants? If so, it might be nice to factor these out. Not a big deal though, because it'd be kind of annoying to do.
Good call. It's a bit janky, but refactored to lut_4bit_to_8bit_convert.
using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy(
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{})));
Could you explain this change?
It's just moved up (to be closer to SmemLayoutA) from below, where the following is deleted:
using SmemLayoutACopy = decltype(tile_to_shape(
SmemLayoutAtomARowMajor{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
the
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}))
is removed since it was just cruft from the original PR that's not actually exercised (that was my bad).
Reviewed the cutlass refactor part - LGTM!
Great work, LGTM!
Signed-off-by: Lucas Wilkinson <[email protected]>
LGTM thanks for getting it green!
Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: Manjul Mohan <[email protected]>
uint4b8 and uint4 since we are just shuffling data around)
TODO (Future PR):