Skip to content

Commit

Permalink
[Cutlass] Support alpha scaling in fp8 group gemm (apache#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Mar 19, 2024
1 parent d8dc9f2 commit 1469e5e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 25 deletions.
14 changes: 7 additions & 7 deletions src/runtime/contrib/cutlass/fp16_group_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@

#include "group_gemm_runner.cuh"


#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

// Kernel configuration for fp16 group GEMM on SM90: ptr-array TMA
// warp-specialized cooperative schedule with a 128x256x64 CTA tile.
// NOTE: the scraped diff showed both pre- and post-format copies of the
// alias lines; the duplicates are removed here, keeping one definition each.
template <>
struct KernelTraits<cutlass::half_t> {
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
};

namespace tvm {
Expand All @@ -53,16 +52,17 @@ void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDAr
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
float alpha = 1.0f;
float beta = 0.0f;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
stream);
workspace->shape[0], n, k, num_groups, alpha, beta,
static_cast<ElementC*>(out->data), stream);
}

// Expose the fp16 (half in / half out) SM90 group GEMM as a TVM packed func.
// The diff artifact contained both the old two-line and new one-line
// registration; a single registration is kept (double registration of the
// same global name would be an error at load time).
TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90")
    .set_body_typed(tvm_cutlass_group_gemm_sm90<cutlass::half_t, cutlass::half_t, cutlass::half_t>);

} // namespace runtime
} // namespace tvm
Expand Down
15 changes: 7 additions & 8 deletions src/runtime/contrib/cutlass/fp8_group_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,24 @@

#include "group_gemm_runner.cuh"


#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)

// Kernel configuration for fp8 (e4m3) group GEMM on SM90: same tile/cluster
// shapes as the fp16 path, but using the FP8 fast-accumulation cooperative
// schedule. Duplicated alias lines from the scraped diff are removed,
// keeping one definition each.
template <>
struct KernelTraits<cutlass::float_e4m3_t> {
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
  using TileShape = Shape<_128, _256, _64>;  // Threadblock-level tile size
  using ClusterShape = Shape<_2, _2, _1>;    // Shape of the threadblocks in a cluster
};

// e5m2 inputs reuse the e4m3 kernel configuration unchanged. The diff
// artifact held both the old brace-split and new one-line form of this
// specialization; only one definition is kept (redefinition would not compile).
template <>
struct KernelTraits<cutlass::float_e5m2_t> : KernelTraits<cutlass::float_e4m3_t> {};

namespace tvm {
namespace runtime {

template <typename ElementA, typename ElementB, typename ElementC>
void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
NDArray out) {
double alpha, NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
// Recommended size is 4MB.
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
Expand All @@ -57,11 +55,12 @@ void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArr
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
float beta = 0.0f;
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
cutlass_group_gemm(static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, static_cast<ElementC*>(out->data),
stream);
workspace->shape[0], n, k, num_groups, static_cast<float>(alpha), beta,
static_cast<ElementC*>(out->data), stream);
}

TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16")
Expand Down
19 changes: 9 additions & 10 deletions src/runtime/contrib/cutlass/group_gemm_runner.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ inline size_t aligned(size_t value, size_t alignment = 16) {
return (value + alignment - 1) / alignment * alignment;
}


// Per-input-element-type kernel configuration (schedule, tile shape, cluster
// shape). Declared here; specialized per supported dtype in the *_group_gemm.cu
// translation units that include this header.
template <typename T>
struct KernelTraits;

Expand Down Expand Up @@ -86,8 +85,8 @@ struct CutlassGroupGemmRunner {
using ClusterShape = typename KernelTraits<ElementA>::ClusterShape;
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = typename KernelTraits<ElementA>::KernelSchedule; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch
using KernelSchedule = typename KernelTraits<ElementA>::KernelSchedule; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape,
Expand Down Expand Up @@ -152,8 +151,8 @@ __global__ void prepare_group_gemm_arguments(
ptr_A[group_id] = x + prev_rows * k;
ptr_B[group_id] = weight + group_id * k * n;
ptr_D[group_id] = out + prev_rows * n;
problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows),
static_cast<int>(n), static_cast<int>(k)};
problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows), static_cast<int>(n),
static_cast<int>(k)};
stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0});
stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0});
Expand All @@ -162,7 +161,7 @@ __global__ void prepare_group_gemm_arguments(
template <typename ElementA, typename ElementB, typename ElementC>
void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
ElementC* out, cudaStream_t stream) {
float alpha, float beta, ElementC* out, cudaStream_t stream) {
using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
using StrideA = typename Runner::StrideA;
using StrideB = typename Runner::StrideB;
Expand All @@ -185,11 +184,11 @@ void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t*
offset += aligned(sizeof(StrideB) * num_groups);
StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
offset += aligned(sizeof(StrideC) * num_groups);
prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(
ptr_A, ptr_B, ptr_D, problem_sizes, stride_A, stride_B, stride_D, x, weight, out, indptr, n,
k, num_groups);
prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes,
stride_A, stride_B, stride_D, x,
weight, out, indptr, n, k, num_groups);
offset = aligned(offset, 256);
runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
workspace_size - offset, num_groups, 1.0f, 0.0f, stream);
workspace_size - offset, num_groups, alpha, beta, stream);
}

0 comments on commit 1469e5e

Please sign in to comment.