[Kernel] Tuned int8 kernels for Ada Lovelace #6848

Merged
1 commit, merged on Jul 30, 2024
27 changes: 10 additions & 17 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -4,7 +4,8 @@

#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"

/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
@@ -98,39 +99,31 @@ template <template <typename, typename> typename Epilogue,
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
-using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
-using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
-using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
-
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
-return vllm::cutlass_gemm_caller<
-vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
-int8_t, cutlass::bfloat16_t, Epilogue,
-TileShape, WarpShape, InstructionShape, 5>>(
+return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
+Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
-return vllm::cutlass_gemm_caller<
-vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
-int8_t, cutlass::half_t, Epilogue, TileShape,
-WarpShape, InstructionShape, 5>>(
+return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
+Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

if (out.dtype() == torch::kBFloat16) {
-return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
-cutlass::bfloat16_t, Epilogue>(
+return vllm::cutlass_gemm_sm89_fp8_dispatch<
+cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
-return vllm::cutlass_gemm_sm89_dispatch<cutlass::float_e4m3_t,
-cutlass::half_t, Epilogue>(
+return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
+cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
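For reference, the deleted lines above pinned every int8 GEMM to a single configuration: a 128x128x64 threadblock tile, a 64x64x64 warp tile, the 16x8x32 tensor-core MMA instruction, and 5 stages. The new scaled_mm_c2x_sm89_int8_dispatch.cuh header (its contents are not shown in this view) presumably selects such a configuration per problem shape, as the FP8 dispatcher further down does. Below is a minimal standalone sketch of how the three GemmShape levels nest; the include path and the 2x2x1 warp arithmetic are stated as assumptions for illustration, not taken from the PR.

#include "cutlass/gemm/gemm.h"  // assumed location of cutlass::gemm::GemmShape

// The removed, one-size-fits-all int8 configuration.
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;      // per threadblock
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;        // per warp
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;  // per tensor-core MMA

// 2 x 2 x 1 warps cooperate on one threadblock tile; the tuned configs in the
// new dispatch headers vary these shapes (and the stage count) with the GEMM M.
static_assert(TileShape::kM / WarpShape::kM == 2 &&
              TileShape::kN / WarpShape::kN == 2 &&
              TileShape::kK / WarpShape::kK == 1,
              "warp tiling of the threadblock tile");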
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
@@ -4,15 +4,15 @@
#include "cutlass/float8.h"

/**
-* This file defines Gemm kernel configurations for SM89 based on the Gemm
+* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/

namespace vllm {

template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
-struct sm89_fallback_gemm {
+struct sm89_fp8_fallback_gemm {
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
@@ -25,7 +25,7 @@ struct sm89_fallback_gemm {
FP8MathOperator>;
};
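As a quick arithmetic check of the "Shared Memory required by this Gemm - 61440 bytes" comment above: with the 64x128x64 tile and 1-byte float_e4m3_t operands, each pipeline stage keeps one A tile (64x64 bytes) and one B tile (64x128 bytes) in shared memory, so 61440 bytes corresponds to a 5-stage mainloop. The stage count is not visible in this hunk, so it is an assumption in the sketch below.

// Sketch only: the 5-stage count is assumed, not shown in the diff above.
constexpr int kTileM = 64, kTileN = 128, kTileK = 64;
constexpr int kBytesPerElem = 1;  // sizeof(cutlass::float_e4m3_t)
constexpr int kStages = 5;        // assumed
constexpr int kSmemPerStage =
    (kTileM * kTileK + kTileK * kTileN) * kBytesPerElem;  // A tile + B tile
static_assert(kStages * kSmemPerStage == 61440, "matches the comment above");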

-struct sm89_config_default {
+struct sm89_fp8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -40,7 +40,8 @@ struct sm89_config_default {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
-typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+typename sm89_fp8_fallback_gemm<InType, OutType,
+Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
@@ -74,7 +75,7 @@
}
};

-struct sm89_config_M256 {
+struct sm89_fp8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -89,7 +90,8 @@ struct sm89_config_M256 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
-typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+typename sm89_fp8_fallback_gemm<InType, OutType,
+Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
@@ -114,7 +116,7 @@ struct sm89_config_M256 {
}
};

-struct sm89_config_M128 {
+struct sm89_fp8_config_M128 {
// M in (64, 128]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -129,7 +131,8 @@ struct sm89_config_M128 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
-typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+typename sm89_fp8_fallback_gemm<InType, OutType,
+Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
@@ -163,7 +166,7 @@ struct sm89_config_M128 {
}
};

-struct sm89_config_M64 {
+struct sm89_fp8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

@@ -176,7 +179,8 @@ struct sm89_config_M64 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
-typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+typename sm89_fp8_fallback_gemm<InType, OutType,
+Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
@@ -215,7 +219,7 @@
}
};

-struct sm89_config_M32 {
+struct sm89_fp8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
@@ -229,7 +233,8 @@ struct sm89_config_M32 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
-typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+typename sm89_fp8_fallback_gemm<InType, OutType,
+Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
@@ -265,7 +270,7 @@
}
};

-struct sm89_config_M16 {
+struct sm89_fp8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
@@ -281,7 +286,8 @@ struct sm89_config_M16 {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

using FallbackGemm =
-typename sm89_fallback_gemm<InType, OutType, Epilogue>::Cutlass2xGemm;
+typename sm89_fp8_fallback_gemm<InType, OutType,
+Epilogue>::Cutlass2xGemm;

uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
@@ -320,10 +326,10 @@ struct sm89_config_M16 {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
-inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,
-torch::Tensor const& a,
-torch::Tensor const& b,
-EpilogueArgs&&... args) {
+inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
+torch::Tensor const& a,
+torch::Tensor const& b,
+EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
@@ -334,27 +340,27 @@ inline void cutlass_gemm_sm89_dispatch(torch::Tensor& out,

if (mp2 <= 16) {
// M in [1, 16]
-return sm89_config_M16::dispatch<InType, OutType, Epilogue>(
+return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
-return sm89_config_M32::dispatch<InType, OutType, Epilogue>(
+return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
-return sm89_config_M64::dispatch<InType, OutType, Epilogue>(
+return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
-return sm89_config_M128::dispatch<InType, OutType, Epilogue>(
+return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
-return sm89_config_M256::dispatch<InType, OutType, Epilogue>(
+return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
-return sm89_config_default::dispatch<InType, OutType, Epilogue>(
+return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
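The renamed dispatcher above keeps the same selection strategy: round the GEMM M dimension up to a power of two and pick the config struct for that bucket. Below is a self-contained sketch of that bucketing with a stand-in next_pow_2 helper; the real helper and any clamping of very small M live in vLLM headers that are not part of this diff, so treat those details as assumptions.

#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Stand-in for vLLM's next_pow_2: round m up to the nearest power of two.
static uint32_t next_pow_2(uint32_t m) {
  uint32_t p = 1;
  while (p < m) p <<= 1;
  return p;
}

// Mirrors the bucket boundaries used by cutlass_gemm_sm89_fp8_dispatch:
// M in [1,16], (16,32], (32,64], (64,128], (128,256], (256,inf).
static const char* pick_config(uint32_t m) {
  uint32_t const mp2 = next_pow_2(m);
  if (mp2 <= 16) return "sm89_fp8_config_M16";
  if (mp2 <= 32) return "sm89_fp8_config_M32";
  if (mp2 <= 64) return "sm89_fp8_config_M64";
  if (mp2 <= 128) return "sm89_fp8_config_M128";
  if (mp2 <= 256) return "sm89_fp8_config_M256";
  return "sm89_fp8_config_default";
}

int main() {
  for (uint32_t m : {1u, 16u, 17u, 100u, 300u, 4096u})
    std::printf("M = %u -> %s\n", m, pick_config(m));
  return 0;
}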