Large performance regression for FP8 E4M3 GEMM with triton==2.3 #3828

Open
mgoin opened this issue May 3, 2024 · 5 comments

Comments


mgoin commented May 3, 2024

There is a very large performance regression (6x slower for [8192,8192]x[8192,8192]) when using Triton for matmuls with float8 e4m3 inputs, comparing 2.2.0 and 2.3.0.

We use Triton for our fused MoE implementation in vLLM and noticed this regression while upgrading PyTorch from 2.2.1 to 2.3.0 (thanks to @pcmoritz for quickly detecting it), which brought an upgrade of Triton as well (2.2.0 -> 2.3.0).

This regression seems to go away if I use the latest nightly, but that still leaves us choosing between very poor FP8 performance with Triton and staying on the latest stable PyTorch (which we would like to have for FP8 GEMM support on SM89). Is it possible this could be hotfixed?

Below I share my minimal reproduction using triton.ops.matmul on an H100:

Results:

> pip install triton==2.2 numpy torch
> python benchmark_fp8.py
Benchmarking [torch.Size([8192, 8192]), fp8e4nv] x [torch.Size([8192, 8192]), fp8e4nv]
Elapsed time for 100 iterations: 0.086390 seconds

> pip install triton==2.3
> python benchmark_fp8.py
Benchmarking [torch.Size([8192, 8192]), fp8e4nv] x [torch.Size([8192, 8192]), fp8e4nv]
Elapsed time for 100 iterations: 0.547999 seconds

> pip uninstall -y triton
> pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
> python benchmark_fp8.py
Benchmarking [torch.Size([8192, 8192]), fp8e4nv] x [torch.Size([8192, 8192]), fp8e4nv]
Elapsed time for 100 iterations: 0.088446 seconds
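
For scale, converting those timings to throughput (my arithmetic, assuming the elapsed time covers only the 100 matmuls):

# Each 8192x8192x8192 GEMM is 2 * 8192^3 ~= 1.1 TFLOP.
flop_per_matmul = 2 * 8192**3
iters = 100
for label, seconds in [("triton==2.2", 0.086390), ("triton==2.3", 0.547999)]:
    tflops = flop_per_matmul * iters / seconds / 1e12
    print(f"{label}: {tflops:.0f} TFLOPS")  # ~1273 vs ~201, i.e. the ~6x regression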

Benchmarking script:

import triton
import triton.ops
import triton.language as tl
import torch
import time

benchmark_iters = 100

# Create input matrices
A = torch.randn(8192, 8192, dtype=torch.float16, device='cuda')
B = torch.randn(8192, 8192, dtype=torch.float16, device='cuda')

# Quantize
A_fp8 = A.to(torch.float8_e4m3fn)
B_fp8 = B.to(torch.float8_e4m3fn).T

# Convert to triton float8 dtype
A_fp8 = triton.reinterpret(A_fp8, tl.float8e4nv)
B_fp8 = triton.reinterpret(B_fp8, tl.float8e4nv)

print(f"Benchmarking [{A_fp8.shape}, {A_fp8.dtype}] x [{B_fp8.shape}, {B_fp8.dtype}]")

# Warm up GPU
for _ in range(10):
    c = triton.ops.matmul(A_fp8, B_fp8)
torch.cuda.synchronize()

# Timing the matmul
start_time = time.time()
for _ in range(benchmark_iters):
    c = triton.ops.matmul(A_fp8, B_fp8)
torch.cuda.synchronize()
elapsed_time = time.time() - start_time

print(f"Elapsed time for {benchmark_iters} iterations: {elapsed_time:.6f} seconds")

@mgoin changed the title from "Large performance regression for FP8 GEMM with triton==2.3" to "Large performance regression for FP8 E4M3 GEMM with triton==2.3" on May 3, 2024

atalman commented May 6, 2024

cc @jansel @malfet @seemethere
This looks like an H100-specific issue. On an A100 I am getting this error instead:

  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/compiler/compiler.py", line 191, in compile
    module = src.make_ir(options)
  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/compiler/compiler.py", line 117, in make_ir
    return ast_to_ttir(self.fn, self, options=options)
  File "/home/atalman/miniconda3/envs/py310/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 45:31:            a = tl.load(A)
            b = tl.load(B)
        else:
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
            a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
            b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
        if AB_DTYPE:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        if fp8_fast_accum:
            acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32)
                               ^
AssertionError('Dot op does not support fp8e4nv on CUDA arch < 90')


mgoin commented May 6, 2024

Hey @atalman, both Triton and PyTorch only support FP8 GEMMs on GPUs with hardware support for FP8 tensor cores, so this is intended to work only on Hopper (H100) or Ada Lovelace (L4, L40, RTX 4000 series).
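
A minimal way to check this up front (a sketch, not code from vLLM or Triton):

import torch

# FP8 tensor-core dot needs SM 8.9 (Ada Lovelace) or SM 9.0 (Hopper);
# A100 is SM 8.0, which is why the AssertionError above fires there.
major, minor = torch.cuda.get_device_capability()
supports_fp8_gemm = (major, minor) >= (8, 9)
print(f"sm_{major}{minor}: FP8 GEMM supported = {supports_fp8_gemm}")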


plotfi commented May 7, 2024

It seems the change to maxNumImpreciseAcc from #2804 brings the run time for matmuls back to 2.2.x levels.

@ThomasRaoux (Collaborator) replied:

> It seems the change to maxNumImpreciseAcc from #2804 brings the run time for matmuls back to 2.2.x levels.

Ah right, this is because before that change the accumulation was happening at a lower precision. To solve that you need to use the three-source dot (acc = tl.dot(a, b, acc) instead of acc += tl.dot(a, b)), because the latter representation suggests the user wants a 32-bit addition.
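
To make the two accumulation styles concrete, here is a minimal hypothetical Triton kernel (fp8_gemm_kernel and fp8_gemm are names I made up; this is not the triton.ops.matmul kernel, and it assumes FP8-capable hardware, contiguous inputs, and M, N, K divisible by the block sizes):

import torch
import triton
import triton.language as tl

@triton.jit
def fp8_gemm_kernel(A, B, C, M, N, K,
                    stride_am, stride_ak, stride_bk, stride_bn,
                    stride_cm, stride_cn,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    a_ptrs = A + rm[:, None] * stride_am + rk[None, :] * stride_ak
    b_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # Three-source dot: the accumulator is threaded through the MMA, so the
        # compiler can keep the fast (lower-precision) FP8 accumulation path.
        acc = tl.dot(a, b, acc)
        # The two-source form plus a separate add requests a full-precision fp32
        # addition after every dot, which is the slow pattern on triton==2.3.0:
        # acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))

def fp8_gemm(a, b):
    # a: (M, K) and b: (K, N), both torch.float8_e4m3fn and contiguous.
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), dtype=torch.float16, device=a.device)
    grid = (M // 128, N // 128)
    fp8_gemm_kernel[grid](a, b, c, M, N, K,
                          a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                          c.stride(0), c.stride(1),
                          BLOCK_M=128, BLOCK_N=128, BLOCK_K=64)
    return c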


pcmoritz commented Jun 6, 2024

Fixed in triton 2.3.1 now :)
