Skip to content

Commit

Permalink
feat: support sm90 cutlass group gemm (#509)
Browse files Browse the repository at this point in the history
Co-authored-by: Zihao Ye <[email protected]>
  • Loading branch information
xslingcn and yzh119 authored Oct 9, 2024
1 parent 20265d6 commit 794bdda
Show file tree
Hide file tree
Showing 14 changed files with 528 additions and 56 deletions.
1 change: 0 additions & 1 deletion flashinfer-aot/csrc_aot/flashinfer_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
Expand Down
26 changes: 26 additions & 0 deletions flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>


torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr,
torch::Tensor weight_indices, torch::Tensor x,
torch::Tensor weight, unsigned int batch_size,
bool weight_column_major);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90");
}
18 changes: 17 additions & 1 deletion flashinfer-aot/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def __init__(self, *args, **kwargs) -> None:
include_dirs = [
str(root.resolve() / "include"),
str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm
str(root.resolve() / "3rdparty" / "cutlass" / "tools" / "util" / "include"),
]
extra_compile_args = {
"cxx": [
Expand All @@ -371,6 +372,10 @@ def __init__(self, *args, **kwargs) -> None:
"-use_fast_math",
],
}
extra_compile_args_sm90 = extra_compile_args.copy()
extra_compile_args_sm90["nvcc"].extend(
"-gencode arch=compute_90a,code=sm_90a".split()
)
ext_modules = []
ext_modules.append(
torch_cpp_ext.CUDAExtension(
Expand All @@ -385,12 +390,23 @@ def __init__(self, *args, **kwargs) -> None:
"csrc/quantization.cu",
"csrc/group_gemm.cu",
"csrc/bmm_fp8.cu",
"csrc_aot/flashinfer_ops.cu",
"csrc_aot/flashinfer_ops.cu"
],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
)
)
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="flashinfer._kernels_sm90",
sources=[
"csrc/group_gemm_sm90.cu",
"csrc_aot/flashinfer_sm90_ops.cu",
],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args_sm90,
)
)
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="flashinfer._decode_kernels",
Expand Down
4 changes: 2 additions & 2 deletions include/flashinfer/gemm/group_gemm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe

// NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API,
// so I just use the kernel function directly, need to investigate more.
auto compute_args_kernel = compute_cutlass_group_gemm_args<DType>;
auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args<DType, DType>;
compute_args_kernel<<<batch_size, 1, 0, stream>>>(
problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w,
(DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major);
Expand Down Expand Up @@ -116,4 +116,4 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe

} // namespace flashinfer

#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_
#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_
57 changes: 45 additions & 12 deletions include/flashinfer/gemm/group_gemm_cutlass.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/packed_stride.hpp"

namespace flashinfer {

Expand All @@ -41,21 +46,49 @@ struct cutlass_dtype<nv_bfloat16> {
using type = cutlass::bfloat16_t;
};

template <typename T>
__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
int64_t* w_indices, size_t d_in, size_t d_out,
bool w_column_major) {
template <>
struct cutlass_dtype<__nv_fp8_e4m3> {
using type = cutlass::float_e4m3_t;
};

template <>
struct cutlass_dtype<__nv_fp8_e5m2> {
using type = cutlass::float_e5m2_t;
};

template <typename DTypeIn, typename DTypeOut>
__global__ void compute_sm80_cutlass_group_gemm_args(
cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y,
int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
int i = blockIdx.x;
int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out;
ptr_x[i] = x + xy_indptr[i] * d_in;
ptr_y[i] = y + xy_indptr[i] * d_out;
ld_x[i] = k; // m * k
ld_w[i] = w_column_major ? k : n; // k * n if column major, n * k if row major
ld_y[i] = n; // m * n
w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
x_ptr[i] = x + xy_indptr[i] * k;
y_ptr[i] = y + xy_indptr[i] * n;
x_ld[i] = k; // m * k
w_ld[i] = w_column_major ? k : n; // k * n if column major, n * k if row major
y_ld[i] = n; // m * n
}

template <typename DTypeIn, typename DTypeOut, typename ProblemShape, typename StrideA,
typename StrideB, typename StrideCD>
__global__ void compute_sm90_cutlass_group_gemm_args(
ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y,
int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
int i = blockIdx.x;
int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
all_problems[i] = ProblemShape(m, n, k);
w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
x_ptr[i] = x + xy_indptr[i] * k;
y_ptr[i] = y + xy_indptr[i] * n;

x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
w_stride[i] = w_column_major ? cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1})
: cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1});
}

} // namespace group_gemm
Expand Down
Loading

0 comments on commit 794bdda

Please sign in to comment.