[Bugfix] Fix compute datatype for cutlass 3.x epilogues (vllm-project…
tlrmchlsmth authored and robertgshaw2-neuralmagic committed Jul 1, 2024
1 parent 6664f2a commit 42cdb40
Showing 2 changed files with 70 additions and 59 deletions.
4 changes: 2 additions & 2 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -144,14 +144,14 @@ struct ScaledEpilogueBias
   using ScaleB = typename SUPER::ScaleB;

   using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiplies, ElementD, ElementD,
+      cutlass::multiplies, float, float,
       cutlass::FloatRoundStyle::round_to_nearest>;

   using EVTCompute0 =
       cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

   using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
-      cutlass::multiply_add, ElementD, ElementD,
+      cutlass::multiply_add, ElementD, float,
       cutlass::FloatRoundStyle::round_to_nearest>;

   using BiasDescriptor =
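For context, this change switches the epilogue's intermediate compute type from the output element type (ElementD) to float, so the scale multiplications are accumulated in fp32 and only the final result is rounded to the output dtype. Below is a rough, illustrative PyTorch sketch of those numerics; it mirrors the baseline_scaled_mm reference added to the tests further down, not the fused CUTLASS epilogue itself, and the shapes and names are made up for the example.

    import torch

    # Toy sizes: per-token activation scales, per-output-channel weight scales.
    m, n, k = 4, 8, 16
    a = torch.randint(-128, 127, (m, k), dtype=torch.int8)
    b = torch.randint(-128, 127, (k, n), dtype=torch.int8)
    scale_a = torch.rand((m, 1), dtype=torch.float32)
    scale_b = torch.rand((1, n), dtype=torch.float32)
    bias = torch.rand((n, ), dtype=torch.bfloat16)

    # Accumulate the matmul and both scale multiplications in float32 ...
    acc = torch.mm(a.to(torch.float32), b.to(torch.float32))
    # ... and convert to the output dtype only at the end, then apply the bias.
    out = (scale_a * (scale_b * acc)).to(torch.bfloat16) + bias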
125 changes: 68 additions & 57 deletions tests/kernels/test_cutlass.py
@@ -2,7 +2,7 @@
 Run `pytest tests/kernels/test_cutlass.py`.
 """
-from typing import Type
+from typing import Optional, Type

 import pytest
 import torch
@@ -32,12 +32,27 @@ def to_int8(tensor: torch.Tensor):
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


+def baseline_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: Type[torch.dtype],
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+
+    output = (scale_a * (scale_b * (torch.mm(
+        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
+    if bias is not None:
+        output = output + bias
+
+    return output
+
+
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
-                            bias: bool,
+                            use_bias: bool,
                             out_dtype: Type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -48,31 +63,27 @@ def cutlass_fp8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1

-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None

-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

-    assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)
+    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)


 def cutlass_int8_gemm_helper(m: int,
                              n: int,
                              k: int,
                              per_token_act_quant: bool,
                              per_out_channel_weight_quant: bool,
-                             bias: bool,
+                             use_bias: bool,
                              out_dtype: Type[torch.dtype] = torch.bfloat16,
                              device: str = "cuda"):
     # Test for a cutlass kernel with per-token activation quantization
@@ -83,22 +94,19 @@ def cutlass_int8_gemm_helper(m: int,
     m_a_scales = m if per_token_act_quant else 1
     n_b_scales = n if per_out_channel_weight_quant else 1

-    scale_a = (torch.randn(
-        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
-    scale_b = (torch.randn(
-        (1, n_b_scales), device=device, dtype=torch.float32) / 10)
+    scale_a = (torch.randn((m_a_scales, 1), device=device,
+                           dtype=torch.float32))
+    scale_b = (torch.randn((1, n_b_scales), device=device,
+                           dtype=torch.float32))

-    if bias:
-        # bias term should be > 1 so that the absolute tolerance can catch it
-        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
+    if use_bias:
+        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
     else:
-        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
-        bias_t = 0
+        bias = None

-    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
-                         scale_b * b.to(dtype=torch.float32)) +
-                bias_t).to(dtype=out_dtype)
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)


@@ -107,7 +115,7 @@ def cutlass_int8_gemm_helper(m: int,
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 # UPSTREAM SYNC: This is currently 90, because we need CUDA 12.4
 # to use the cutlass fp8 kernels + we do not have this in our
 # automation system yet.
@@ -116,41 +124,41 @@ def cutlass_int8_gemm_helper(m: int,
                     "type because we need CUDA 12.4 + we do "
                     "not have this in automation yet.")
 def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                          per_out_ch: bool, bias: bool):
-    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                          per_out_ch: bool, use_bias: bool):
+    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


 @pytest.mark.parametrize("m", [512, 222, 33, 1])
 @pytest.mark.parametrize("n", [2048, 256, 1024])
 @pytest.mark.parametrize("k", [128, 496, 1024])
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
-                           per_out_ch: bool, bias: bool):
-    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)
+                           per_out_ch: bool, use_bias: bool):
+    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                         out_dtype: Type[torch.dtype],
-                                        bias: bool):
+                                        use_bias: bool):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=out_dtype)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
 @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 # UPSTREAM SYNC: This is currently 90, because we need CUDA 12.4
 # to use the cutlass fp8 kernels + we do not have this in our
 # automation system yet.
@@ -160,19 +168,19 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                     "not have this in automation yet.")
 def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: Type[torch.dtype],
-                                       bias: bool):
+                                       use_bias: bool):
     cutlass_fp8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
-                            bias,
+                            use_bias,
                             out_dtype=out_dtype)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 # UPSTREAM SYNC: This is currently 90, because we need CUDA 12.4
 # to use the cutlass fp8 kernels + we do not have this in our
@@ -182,23 +190,23 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                     "type because we need CUDA 12.4 + we do "
                     "not have this in automation yet.")
 def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool, device: str):
-    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
+                                  use_bias: bool, device: str):
+    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
                             torch.bfloat16, device)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool, device: str):
+                                   use_bias: bool, device: str):
     cutlass_int8_gemm_helper(512,
                              512,
                              512,
                              per_act_token,
                              per_out_ch,
-                             bias,
+                             use_bias,
                              out_dtype=torch.bfloat16,
                              device=device)

@@ -210,7 +218,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
 # kernel must handle any M thrown at it.
 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 # UPSTREAM SYNC: This is currently 90, because we need CUDA 12.4
 # to use the cutlass fp8 kernels + we do not have this in our
 # automation system yet.
@@ -219,21 +227,22 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                     "type because we need CUDA 12.4 + we do "
                     "not have this in automation yet.")
 def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                  bias: bool):
+                                  use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
-            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)
+            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
+                                    use_bias)


 @pytest.mark.parametrize("per_act_token", [True, False])
 @pytest.mark.parametrize("per_out_ch", [True, False])
-@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
 def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
-                                   bias: bool):
+                                   use_bias: bool):
     for nk in range(32, 128, 32):
         for m in range(1, 128):
             cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
-                                     bias)
+                                     use_bias)


 # Test working with a subset of A and B
@@ -254,9 +263,11 @@ def test_cutlass_subset():
                                 scale_a,
                                 scale_b,
                                 out_dtype=torch.bfloat16)
-    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
-                        scale_b *
-                        b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)

     assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)

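To make the new comparison concrete, here is a hedged, standalone sketch of what one int8 case now checks on a CUDA-capable vLLM build. The ops import path and the A/B construction are not visible in this excerpt, so both are assumptions; the cutlass_scaled_mm call signature is the one used in the tests above.

    import torch

    from vllm import _custom_ops as ops  # assumed import path; not shown in this diff

    m, n, k = 512, 512, 512
    a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
    # B is passed column-major here; .t() of a contiguous (n, k) tensor gives that layout.
    b = torch.randint(-128, 127, (n, k), dtype=torch.int8, device="cuda").t()
    scale_a = torch.randn((m, 1), dtype=torch.float32, device="cuda")
    scale_b = torch.randn((1, n), dtype=torch.float32, device="cuda")
    bias = torch.rand((n, ), dtype=torch.bfloat16, device="cuda") * 10

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, bias)
    # Reference computed in fp32 and converted to the output dtype last,
    # the same as baseline_scaled_mm in the patch above.
    ref = (scale_a * (scale_b * torch.mm(a.float(), b.float()))).to(torch.bfloat16) + bias
    assert torch.allclose(out, ref, rtol=1e-1, atol=1e0)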