
Integrate int4kv to benchmark_attn_decode.py #1029

Merged: 6 commits from int4kv_fa_decode into facebookresearch:main, May 8, 2024

Conversation

scxiao (Contributor) commented Apr 19, 2024

What does this PR do?

Integrate the int4kv standalone test into the benchmark script benchmark_attn_decode.py.
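For context, a minimal sketch of what per-group int4 KV quantization generally looks like; the function name, group size, and two-codes-per-byte packing are illustrative assumptions and not necessarily the exact layout used by the triton_splitk int4 path:

import torch

def quantize_kv_int4(x: torch.Tensor, group_size: int = 64):
    """Illustrative per-group 4-bit quantization of a K or V tensor (hypothetical helper)."""
    # Split the last (head) dimension into groups and compute per-group min/max.
    xg = x.reshape(*x.shape[:-1], -1, group_size)
    xmin = xg.amin(dim=-1, keepdim=True)
    xmax = xg.amax(dim=-1, keepdim=True)
    scale = (xmax - xmin).clamp(min=1e-6) / 15.0   # 4-bit codes cover 0..15
    q = ((xg - xmin) / scale).round().clamp(0, 15).to(torch.uint8)
    # Pack two 4-bit codes into each stored byte.
    packed = q[..., 0::2] | (q[..., 1::2] << 4)
    return packed, scale, xmin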

Before submitting

  • Did you have fun?
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

facebook-github-bot (Contributor)

Hi @scxiao!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@scxiao scxiao marked this pull request as draft April 19, 2024 13:46
codecov-commenter commented Apr 19, 2024

Codecov Report

Attention: Patch coverage is 0%, with 1 line in your changes missing coverage. Please review.

Project coverage is 59.93%. Comparing base (67f38b9) to head (50d747c).

Files                                Patch %   Lines
xformers/ops/fmha/triton_splitk.py   0.00%     1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1029   +/-   ##
=======================================
  Coverage   59.93%   59.93%           
=======================================
  Files         113      113           
  Lines       10007    10007           
=======================================
  Hits         5998     5998           
  Misses       4009     4009           
Flag Coverage Δ
Python 59.93% <0.00%> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@scxiao scxiao marked this pull request as ready for review April 19, 2024 19:32
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 19, 2024
facebook-github-bot (Contributor)

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

scxiao (Contributor, Author) commented Apr 24, 2024

Hi @jianyuh, could you please help review this PR to include int4kv in the FA decode benchmark? It seems I cannot add reviewers here. Thanks!

@jianyuh jianyuh requested review from sgrigory and bottler April 25, 2024 06:12
jianyuh (Member) commented Apr 25, 2024

Thanks @scxiao! Could you fix the lint issue?

class AttentionDecodingPyTorchRepeat(AttentionDecodingBase):
    def fw(self) -> None:
        B, Mq, Mkv, Hq, Hkv, K = self.shapes
        scale = 1 / K**0.5
        q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3)
        k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
        v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3)
-       attn = (q @ k.transpose(-1, -2)).softmax(-1) * scale
+       attn = (q @ k.transpose(-1, -2) * scale).softmax(-1)
Member

Thanks for catching this! Looks like we didn't sync with https://github.com/ROCm/phantom_amd_llm_operators/pull/18/files ..
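A quick standalone check (a minimal sketch with made-up shapes, not code from this PR) shows why the ordering matters: the scale must be applied to the attention logits before the softmax, not to the probabilities afterwards:

import torch

B, H, M, K = 1, 2, 4, 8
q = torch.randn(B, H, M, K)
k = torch.randn(B, H, M, K)
scale = 1 / K**0.5

old = (q @ k.transpose(-1, -2)).softmax(-1) * scale   # scales probabilities: rows no longer sum to 1
new = (q @ k.transpose(-1, -2) * scale).softmax(-1)   # scales logits, the standard attention formulation

print(torch.allclose(old, new))     # False
print(old.sum(-1)[0, 0, 0].item())  # ~0.35 (= scale) instead of 1.0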

@@ -637,7 +637,7 @@ def dequantize(
            x_[:, :, None, :] >> offsets
        )  # (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)

-       quant_offset = tl.reshape(
+       quant_offset = tl.view(
Contributor

I think tl.view, which we used to use here and which used to do a non-obvious thing, still officially has unspecified behaviour in Triton. tl.reshape actually claims to return the right thing, which is nice 😊. Are you changing this here because you observed a speedup, because you expect one, or is this something else, e.g. a fix?

Contributor, Author

I saw an error complaining that there is no tl.reshape and suggesting tl.view instead. Let me double-check.
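For reference, a minimal round-trip sketch of the primitive under discussion, assuming a recent Triton release where tl.reshape is available (the kernel and shapes here are made up for illustration); tl.reshape documents that element order is preserved, while the older tl.view left the order unspecified:

import torch
import triton
import triton.language as tl

@triton.jit
def roundtrip_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # tl.reshape keeps elements in order, so reshaping to 2D and back is a no-op.
    y = tl.reshape(x, (BLOCK // 8, 8))
    tl.store(out_ptr + offs, tl.reshape(y, (BLOCK,)))

if torch.cuda.is_available():
    x = torch.arange(64, device="cuda", dtype=torch.float32)
    out = torch.empty_like(x)
    roundtrip_kernel[(1,)](x, out, BLOCK=64)
    assert torch.equal(x, out)  # holds for tl.reshape; tl.view makes no such guarantee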

scxiao (Contributor, Author) commented Apr 25, 2024

Thanks @scxiao! Could you fix the lint issue?

It seems I did not change that file, but I fixed the formatting anyway. Is that OK with you? Thanks!

@@ -141,14 +178,93 @@ class AttentionDecodingCKSplitKV(AttentionDecodingBase):
    OP = xops.fmha.ck_splitk.FwOp


+class AttentionDecodingSplitInt4KV(AttentionDecodingBase):
+    OP = xops.fmha.triton_splitk.FwOp
+    def __init__(self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool,
Member

We have lint issues.

./xformers/benchmarks/benchmark_attn_decoding.py:183:5: E301 expected 1 blank line, found 0
    def __init__(self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool,
    ^
./xformers/benchmarks/benchmark_attn_decoding.py:185:5: E124 closing bracket does not match visual indentation
    ) -> None:
    ^
./xformers/benchmarks/benchmark_attn_decoding.py:258:1: W293 blank line contains whitespace

^

@scxiao scxiao force-pushed the int4kv_fa_decode branch from 1f30ac9 to dfd0f39 Compare April 26, 2024 19:57
scxiao (Contributor, Author) commented Apr 29, 2024

Hi @sgrigory, could you please help review this PR when you get a chance? Thanks

@@ -569,9 +569,9 @@ def merge_attentions(
    concat_path = attn_is_concat and lse_is_concat
    if not concat_path:
        if attn_is_concat:
            attn_split = cast(torch.Tensor, attn_split).unbind(0)
Contributor

Was this change fixing a failure, or was it to make the type checker happy, or something else?

Contributor

This change seems wrong to me.

Contributor

The typing is not quite right at the moment: unbind returns a tuple, while the code that uses it expects a list, I think, although it really only needs a Sequence.

This change will /work/ because a Tensor is usable as a Sequence of tensors, which is what the non-concat path needs, so the unbind was unnecessary. And this change probably also fixes the type error because mypy doesn't get finer than knowing attn_split is a Union. But this change does make the code slightly less understandable.

I think OP is just trying to fix the existing lint errors in this PR: here and in the predicated...h file (which is clearly the correct fix). I'm fine if they leave them broken or just this one broken.
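A tiny illustration of the Tensor-as-Sequence point (made-up shapes, purely illustrative): iterating a stacked tensor yields its slices along dim 0, which is exactly what unbind(0) returns, so the non-concat path can consume the tensor directly:

import torch

attn_split = torch.randn(3, 2, 4)  # hypothetical stacked per-split attention chunks

as_tensor_seq = list(attn_split)          # 3 tensors of shape (2, 4), sliced along dim 0
as_unbound = list(attn_split.unbind(0))   # the same 3 tensors, returned as a tuple

assert all(torch.equal(a, b) for a, b in zip(as_tensor_seq, as_unbound))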

Contributor, Author

Thanks @bottler. Yes, the changes here are to fix lint errors.

@scxiao scxiao force-pushed the int4kv_fa_decode branch from f965a72 to 50d747c Compare May 3, 2024 22:28
scxiao (Contributor, Author) commented May 6, 2024

Hi @sgrigory, when you get a chance, could you please help review this PR so we can get it merged soon? Thanks.
CC: @jianyuh.

sgrigory (Contributor) commented May 6, 2024 via email

bottler (Contributor) commented May 7, 2024

LGTM except for the reshape->view change, which I'm still uncertain about. Can we leave it unchanged?

scxiao (Contributor, Author) commented May 8, 2024

LGTM except for the reshape->view change, which I'm still uncertain about. Can we leave it unchanged?

I just tried again; it seems the error related to reshape comes from our ROCm fork. I reverted this change.

@bottler bottler merged commit 60d5f11 into facebookresearch:main May 8, 2024
2 of 4 checks passed