Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel][Triton] Add Triton implementation for scaled_mm_triton to support fp8 and int8 SmoothQuant, symmetric case #9857

Merged
merged 14 commits into from
Nov 9, 2024

Conversation

rasmith
Copy link
Contributor

@rasmith rasmith commented Oct 30, 2024

This PR adds support for running SmoothQuant models that only use the symmetric case, no zero-point adjustment. This does so by adding scaled_mm_triton() kernel and checking is_hip() in cutlass_scaled_mm().

Now, currently works with Phi-3-medium-128k-instruct-quantized.w8a8 .

Will continue to test and benchmark to add more support for int8 SmoothQuant.
FIX #xxxx (link existing issues this PR will resolve)

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Adding or changing kernels

Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

  • Make sure custom ops are registered following PyTorch guidelines: Custom C++ and CUDA Operators and The Custom Operators Manual
  • Custom operations that return Tensors require meta-functions. Meta-functions should be implemented and registered in python so that dynamic dims can be handled automatically. See above documents for a description of meta-functions.
  • Use torch.libary.opcheck() to test the function registration and meta-function for any registered ops. See tests/kernels for examples.
  • When changing the C++ signature of an existing op, the schema must be updated to reflect the changes.
  • If a new custom type is needed, see the following document: Custom Class Support in PT2.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@rasmith
Copy link
Contributor Author

rasmith commented Oct 30, 2024

/ready

if use_scalar_scale_b:
scale_b = torch.rand((1, 1), device=device)
else:
scale_b = 0.25 * torch.rand((1, 1), device=device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be?

Suggested change
scale_b = 0.25 * torch.rand((1, 1), device=device)
scale_b = 0.25 * torch.rand((1, N), device=device)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed, thanks for catching that! I retested and everything still works!

Comment on lines 8 to 13
# This function handles some cases that can cause certain failure, e.g.
# a tensor that has shape = (72, 48) but stride = (5120, 1). It can happen,
# for example by saving a tensor using torch.save() and then adjusting its
# size afterwards and then trying to use it. Unfortunately,
# torch.is_contiguous() doesn't help since a transposed tensor doesn't return
# True, even though it can be stored contiguously in memory.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a fundamental reason we can't handle this case? This case can arise when working with slices of a tensor, and I went out of my way to support it for the cutlass_scaled_mm kernels. Supporting this is definitely a Nice To Have rather than a requirement but would like to know what the problem is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it might have been occurring in vllm, but I haven't seen it happen. I did it to myself when I was debugging though. I did a torch.save() when I found a tensor that didn't have the correct output, but the tensor was huge, so I changed the size of the tensor to (72, 48) while the old size was something like (131072, 5120). However, the strides remained (5120, 1) and my kernel got the incorrect result, while torch.__int_mm() would get the correct result for torch._int_mm(a,b). I was wondering why that was, so I looked at the pytorch code to figure it out and found out what they were doing. I put it there, just in case, but I actually haven't seen this happen in vLLM though. I thought about removing the prepare_matrix_for_triton() function, but this scaled_mm_triton() function is general enough that it could be used elsewhere (it handles any input dtype and output dtype I've tried so far and can mix and match per-tensor, row-wise, and column-wise scaling), so I left it in.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try taking this out and test, and put an assert where it was used instead, per suggestion on the other comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just using assert now.


# input - [M, K]
# weight - [K, N]
def scaled_mm_triton(input: torch.Tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe call it triton_scaled_mm just for consistency since we already have cutlass_scaled_mm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

block_size_sb = 1 if has_scalar(scale_b) else block_size_n

input = prepare_matrix_for_triton(input)
weight = prepare_matrix_for_triton(weight)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the weights need to be preprocessed, it's be better to do this during process_weights_after_loading, rather than during the forward pass as this would be very slow.

I suggest replacing this with an assert or adding a warning so it's obvious that there's a problem rather than silently having a performance regression.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just using assert now.

[M, 1])
assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
[N, 1])
assert torch.empty((1, 1), dtype=out_dtype).is_floating_point()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert torch.empty((1, 1), dtype=out_dtype).is_floating_point()
assert out_dtype.is_floating_point

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@pytest.mark.parametrize("N", [2048, 8192, 16384, 256, 1024])
@pytest.mark.parametrize("K", [128, 496, 1024])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("in_dtype", [torch.int8])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does scaled_mm_triton support fp8_e4m3_fnuz?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, scaled_mm_triton supports pretty much everything, as long as the hardware and Triton supports it. It's just that the tests will take super long to run if I run it on all types. The fp8 types aren't supported uniformly, e.g. e5m2 seems to only have some limited support on AMD hardware right now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, could you expand the test coverage to cover fp8? vLLM doesn't use e5m2 for linear layers currently, but will use fp8_e4m3 extensively. Ideally we would detect if we're on a CUDA vs RoCM system and test using fp8_e4m3_fn or fp8_e4m3_fnuz accordingly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing covering fp8 now.

Comment on lines 80 to 83
offsets_scale_a = (offsets_scale_am[:, None].to(tl.int64) +
tl.arange(0, 1)[None, :].to(tl.int64))
offsets_scale_b = (offsets_scale_bn[:, None].to(tl.int64) +
tl.arange(0, 1)[None, :].to(tl.int64))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain what's going on here? It seems like this could be simplified and left as 1D with broadcasting happening during the load.

Also I would think you could leave these as int32 since they are vectors

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be able to simplify this, going to give it a shot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplified.

vllm/_custom_ops.py Outdated Show resolved Hide resolved
@tlrmchlsmth
Copy link
Collaborator

Thanks for the contribution! Will be nice to have RoCM support here, and it's nice to have triton implementations for these in general. Overall looks good, left a few in-line comments

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Randall Smith <[email protected]>
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How long do the test_scaled_mm_triton.py tests take on the AMD GPUs you've tested on? They are taking quite a while, like 10s of minutes on an H100 and we have been trying to keep the kernel unit test time in check

tests/kernels/test_scaled_mm_triton.py Outdated Show resolved Hide resolved
tests/kernels/test_scaled_mm_triton.py Outdated Show resolved Hide resolved
@tlrmchlsmth tlrmchlsmth changed the title [Kernel][Triton] Add Triton implementation for scaled_mm_triton to support int8 SmoothQuant, symmetric case [Kernel][Triton] Add Triton implementation for scaled_mm_triton to support fp8 and int8 SmoothQuant, symmetric case Nov 6, 2024
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran the tests on an L40 as well, looks good.

I'm accepting but please fix the issue related to skipping tests that I spotted

Comment on lines 49 to 50
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should detect if in_dtype is fp8 as well, otherwise we're skipping the int8 tests when we don't need to

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tlrmchlsmth I just changed get_8bit_types() to only include fp8 if current_platform.has_device_capability(89) returns true. Will this work for you?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep!

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 7, 2024
Comment on lines 34 to 35
supports_fp8 = current_platform.has_device_capability(89)
if current_platform.is_rocm() and supports_fp8:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this right for rocm though? What does current_platform.has_device_capability(89) return for AMD systems?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it returns true on MI300.

@rasmith
Copy link
Contributor Author

rasmith commented Nov 8, 2024

@tlrmchlsmth Merge?

@tlrmchlsmth tlrmchlsmth merged commit 127c074 into vllm-project:main Nov 9, 2024
54 checks passed
@tlrmchlsmth
Copy link
Collaborator

@rasmith merged -- thanks for the ping!

omer-dayan pushed a commit to omer-dayan/vllm that referenced this pull request Nov 10, 2024
…pport fp8 and int8 SmoothQuant, symmetric case (vllm-project#9857)

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: OmerD <[email protected]>
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
…pport fp8 and int8 SmoothQuant, symmetric case (vllm-project#9857)

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
jeejeelee pushed a commit to jeejeelee/vllm that referenced this pull request Nov 11, 2024
…pport fp8 and int8 SmoothQuant, symmetric case (vllm-project#9857)

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 13, 2024
…pport fp8 and int8 SmoothQuant, symmetric case (vllm-project#9857)

Signed-off-by: Randall Smith <[email protected]>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
…pport fp8 and int8 SmoothQuant, symmetric case (vllm-project#9857)

Signed-off-by: Randall Smith <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants