-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CUTLASS] Add FP8 gemm kernels (#17408)
This PR introduces the sm90a FP8 GEMM kernels from CUTLASS. These kernels help when `M` is small, a case in which cuBLAS is not well optimized.
- Loading branch information
1 parent
5648a8e
commit 4e70e4a
Showing
5 changed files
with
349 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#include <cuda_fp16.h> | ||
#include <float.h> | ||
#include <tvm/runtime/ndarray.h> | ||
#include <tvm/runtime/packed_func.h> | ||
#include <tvm/runtime/registry.h> | ||
|
||
#include "../cublas/cublas_utils.h" | ||
#include "gemm_runner.cuh" | ||
|
||
#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) | ||
|
||
struct KernelTraitsM64 { | ||
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; | ||
using TileShape = Shape<_64, _64, _128>; | ||
using ClusterShape = Shape<_1, _8, _1>; | ||
}; | ||
|
||
namespace tvm { | ||
namespace runtime { | ||
|
||
template <typename ElementA, typename ElementB, typename ElementC> | ||
void tvm_cutlass_fp8_gemm(NDArray x, NDArray weight, NDArray workspace, NDArray alpha, | ||
NDArray out) { | ||
// Workspace is used for storing device-side gemm arguments and cutlass internal workspace. | ||
// Recommened size is 4MB. | ||
auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); | ||
ICHECK(func != nullptr); | ||
CHECK_GE(x->ndim, 2); | ||
CHECK_EQ(weight->ndim, 2); | ||
CHECK_EQ(workspace->ndim, 1); | ||
CHECK_GE(out->ndim, 2); | ||
CHECK_EQ(alpha->dtype.code, kDLFloat); | ||
CHECK_EQ(alpha->dtype.bits, 32); | ||
CHECK_EQ(alpha->ndim, 1); | ||
CHECK_EQ(alpha->shape[0], 1); | ||
int64_t m = 1; | ||
for (int i = 0; i < x->ndim - 1; ++i) { | ||
m *= x->shape[i]; | ||
} | ||
int64_t n = weight->shape[0]; | ||
CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight is supported now."; | ||
int64_t k = x->shape[x->ndim - 1]; | ||
const float* beta = nullptr; | ||
cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*()); | ||
if (m <= 64) { | ||
cutlass_gemm<KernelTraitsM64>( | ||
static_cast<ElementA*>(x->data), static_cast<ElementB*>(weight->data), | ||
static_cast<uint8_t*>(workspace->data), workspace->shape[0], m, n, k, | ||
static_cast<float*>(alpha->data), beta, static_cast<ElementC*>(out->data), stream); | ||
} else { | ||
tvm::contrib::CuBlasLtThreadEntry* cublas_entry = | ||
tvm::contrib::CuBlasLtThreadEntry::ThreadLocal(); | ||
tvm::contrib::CallCublasLt(cublas_entry->handle, stream, cublas_entry->matmul_pref_desc, | ||
x.operator->(), weight.operator->(), nullptr, alpha.operator->(), | ||
nullptr, out.operator->(), /*transa=*/false, /*transb=*/true, | ||
cublas_entry->workspace_ptr, cublas_entry->workspace_size, | ||
CUBLASLT_EPILOGUE_DEFAULT, std::nullopt); | ||
} | ||
} | ||
|
||
// Global function registrations. Each name encodes the element types of the
// instantiation as <x>_<weight>_<out>; every variant produces a half-precision
// output element type.

// x: e5m2, weight: e5m2, out: fp16.
TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e5m2_fp16")
    .set_body_typed(tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t,  // ElementA
                                         cutlass::float_e5m2_t,  // ElementB
                                         cutlass::half_t>);      // ElementC

// x: e5m2, weight: e4m3, out: fp16.
TVM_REGISTER_GLOBAL("cutlass.gemm_e5m2_e4m3_fp16")
    .set_body_typed(tvm_cutlass_fp8_gemm<cutlass::float_e5m2_t,  // ElementA
                                         cutlass::float_e4m3_t,  // ElementB
                                         cutlass::half_t>);      // ElementC

// x: e4m3, weight: e4m3, out: fp16.
TVM_REGISTER_GLOBAL("cutlass.gemm_e4m3_e4m3_fp16")
    .set_body_typed(tvm_cutlass_fp8_gemm<cutlass::float_e4m3_t,  // ElementA
                                         cutlass::float_e4m3_t,  // ElementB
                                         cutlass::half_t>);      // ElementC
|
||
} // namespace runtime | ||
} // namespace tvm | ||
|
||
#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
#include <fstream> | ||
#include <iostream> | ||
#include <sstream> | ||
#include <variant> | ||
#include <vector> | ||
|
||
#include "../../cuda/cuda_common.h" | ||
|
||
// clang-format off | ||
#include "cutlass/cutlass.h" | ||
|
||
#include "cute/tensor.hpp" | ||
#include "cutlass/tensor_ref.h" | ||
#include "cutlass/epilogue/collective/default_epilogue.hpp" | ||
#include "cutlass/epilogue/thread/linear_combination.h" | ||
#include "cutlass/gemm/dispatch_policy.hpp" | ||
#include "cutlass/gemm/gemm.h" | ||
#include "cutlass/gemm/collective/collective_builder.hpp" | ||
#include "cutlass/epilogue/collective/collective_builder.hpp" | ||
#include "cutlass/gemm/device/gemm_universal_adapter.h" | ||
#include "cutlass/gemm/kernel/gemm_universal.hpp" | ||
// clang-format on | ||
|
||
// Evaluates `status` exactly once and fails through CHECK, reporting the
// CUTLASS status string, when it is not cutlass::Status::kSuccess.
// The body is wrapped in do/while(0) so the macro expands to a single
// statement and composes correctly with unbraced if/else; callers terminate
// the invocation with a semicolon.
#define CUTLASS_CHECK(status)                                                     \
  do {                                                                            \
    cutlass::Status cutlass_check_status_ = (status);                             \
    CHECK(cutlass_check_status_ == cutlass::Status::kSuccess)                     \
        << "Got cutlass error: " << cutlassGetStatusString(cutlass_check_status_); \
  } while (0)
|
||
using namespace cute; | ||
using ProblemShape = Shape<int, int, int>; // <M, N, K> | ||
|
||
/*!
 * \brief Assembles a CUTLASS 3.x SM90 GEMM from the collective builders and launches it.
 *
 * \tparam KernelTraits Provides the TileShape, ClusterShape and KernelSchedule aliases.
 * \tparam ElementA Element type of operand A.
 * \tparam ElementB Element type of operand B.
 * \tparam ElementC Element type of both the C source and the D destination.
 * \tparam LayoutA Layout tag of A (row-major by default).
 * \tparam LayoutB Layout tag of B (column-major by default).
 * \tparam LayoutC Layout tag of C and D (row-major by default).
 */
template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC,
          typename LayoutA = cutlass::layout::RowMajor,
          typename LayoutB = cutlass::layout::ColumnMajor,
          typename LayoutC = cutlass::layout::RowMajor>
struct CutlassGemmRunner {
  // Alignment of the A matrix in units of elements (up to 16 bytes).
  static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // Alignment of the B matrix in units of elements (up to 16 bytes).
  static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  // Alignment of the C matrix in units of elements (up to 16 bytes).
  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

  // Core kernel configurations
  // Element type for internal accumulation.
  using ElementAccumulator = float;
  // A scale factor is given either by value or through a pointer.
  using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
  // Tag indicating the minimum SM that supports the intended feature.
  using ArchTag = cutlass::arch::Sm90;
  // Operator class tag.
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = typename KernelTraits::TileShape;
  using ClusterShape = typename KernelTraits::ClusterShape;
  // Stage count maximized based on the tile size.
  using StageCountType = cutlass::gemm::collective::StageCountAuto;
  // Kernel to launch.
  using KernelSchedule = typename KernelTraits::KernelSchedule;
  // Epilogue to launch.
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass, TileShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator,
      ElementC, LayoutC, AlignmentC, ElementC, LayoutC, AlignmentC, EpilogueSchedule>::CollectiveOp;
  // The mainloop stage count is chosen automatically after carving out the
  // epilogue's shared storage.
  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
      ElementAccumulator, TileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;

  /*!
   * \brief Build the CUTLASS arguments, validate them, and launch the GEMM on `stream`.
   *
   * \param ptr_A Pointer to operand A.
   * \param ptr_B Pointer to operand B.
   * \param ptr_C Pointer to the epilogue source C.
   * \param ptr_D Pointer to the destination D.
   * \param problem_size Problem extents <M, N, K>.
   * \param stride_A Stride of A.
   * \param stride_B Stride of B.
   * \param stride_C Stride of C.
   * \param stride_D Stride of D.
   * \param workspace Scratch buffer handed to Gemm::initialize.
   * \param workspace_size Size of `workspace`; must be non-negative and at
   *        least Gemm::get_workspace_size(arguments).
   * \param alpha Epilogue alpha, by value or by pointer.
   * \param beta Epilogue beta; must hold the same alternative as `alpha`.
   * \param stream CUDA stream the kernel is initialized and run on.
   */
  void run_gemm(const ElementA* ptr_A, const ElementB* ptr_B, const ElementC* ptr_C,
                ElementC* ptr_D, ProblemShape* problem_size, StrideA* stride_A, StrideB* stride_B,
                StrideC* stride_C, StrideD* stride_D, uint8_t* workspace, int64_t workspace_size,
                ScaleType alpha, ScaleType beta, cudaStream_t stream) {
    // Describe the device that is current on the calling thread instead of
    // assuming device 0, so the SM count matches the GPU the kernel runs on.
    int current_device = 0;
    cudaError_t device_status = cudaGetDevice(&current_device);
    CHECK(device_status == cudaSuccess)
        << "Failed to query the current CUDA device: " << cudaGetErrorString(device_status);

    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = current_device;
    hw_info.sm_count =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

    // The epilogue thread arguments start default-constructed; alpha and beta
    // are filled in below.
    typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm,
                                       *problem_size,
                                       {ptr_A, *stride_A, ptr_B, *stride_B},
                                       {{}, ptr_C, *stride_C, ptr_D, *stride_D},
                                       hw_info};

    ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
    if (std::holds_alternative<ElementAccumulator>(alpha)) {
      // Scalars passed by value.
      arguments.epilogue.thread.alpha = std::get<ElementAccumulator>(alpha);
      arguments.epilogue.thread.beta = std::get<ElementAccumulator>(beta);
    } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
      // Scalars passed by pointer.
      arguments.epilogue.thread.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
      arguments.epilogue.thread.beta_ptr = std::get<const ElementAccumulator*>(beta);
    } else {
      // Only reachable for a valueless variant; LOG(FATAL) reports the failure
      // and does not return.
      LOG(FATAL) << "Unsupported alpha and beta type";
    }

    Gemm gemm_op;
    CUTLASS_CHECK(gemm_op.can_implement(arguments));

    // Compare in the unsigned domain only after ruling out negative sizes;
    // a mixed signed/unsigned comparison would let a negative size pass.
    const size_t required_workspace = gemm_op.get_workspace_size(arguments);
    CHECK_GE(workspace_size, 0) << "workspace_size must be non-negative";
    CHECK_GE(static_cast<size_t>(workspace_size), required_workspace)
        << "The provided workspace is smaller than what the CUTLASS kernel requires";

    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
    CUTLASS_CHECK(gemm_op.run(stream));
  }
};
|
||
/*!
 * \brief Run a CUTLASS GEMM writing an [m, n] output from x [m, k] and weight [n, k].
 *
 * Strides are built for densely packed operands: x and weight use a leading
 * stride of k, the output uses a leading stride of n. `out` is passed both as
 * the epilogue source C and as the destination D.
 *
 * \tparam KernelTraits Kernel configuration forwarded to CutlassGemmRunner.
 * \param x Pointer to operand A.
 * \param weight Pointer to operand B.
 * \param workspace Scratch buffer forwarded to the runner.
 * \param workspace_size Size of `workspace`, forwarded to the runner.
 * \param m Number of rows of x and out; must be representable as a non-negative int.
 * \param n Number of rows of weight / columns of out; same range requirement.
 * \param k Reduction extent; same range requirement.
 * \param alpha Epilogue alpha, by value or by pointer.
 * \param beta Epilogue beta, by value or by pointer.
 * \param out Pointer to the output.
 * \param stream CUDA stream to launch on.
 */
template <typename KernelTraits, typename ElementA, typename ElementB, typename ElementC>
void cutlass_gemm(ElementA* x, ElementB* weight, uint8_t* workspace, int64_t workspace_size,
                  int64_t m, int64_t n, int64_t k, std::variant<float, const float*> alpha,
                  std::variant<float, const float*> beta, ElementC* out, cudaStream_t stream) {
  using Runner = CutlassGemmRunner<KernelTraits, ElementA, ElementB, ElementC>;
  using StrideA = typename Runner::StrideA;
  using StrideB = typename Runner::StrideB;
  using StrideC = typename Runner::StrideC;
  using StrideD = typename Runner::StrideD;

  // ProblemShape stores plain ints. Reject extents that do not survive the
  // narrowing instead of silently launching a truncated problem.
  const int m_int = static_cast<int>(m);
  const int n_int = static_cast<int>(n);
  const int k_int = static_cast<int>(k);
  CHECK(m >= 0 && static_cast<int64_t>(m_int) == m) << "m is out of range for cutlass_gemm: " << m;
  CHECK(n >= 0 && static_cast<int64_t>(n_int) == n) << "n is out of range for cutlass_gemm: " << n;
  CHECK(k >= 0 && static_cast<int64_t>(k_int) == k) << "k is out of range for cutlass_gemm: " << k;

  Runner runner;
  StrideA stride_A = cute::make_stride(k, Int<1>{}, int64_t{0});
  StrideB stride_B = cute::make_stride(k, Int<1>{}, int64_t{0});
  // C and D alias the same buffer, so they share the same stride values, but
  // each is held in a variable of its own stride type.
  StrideC stride_C = cute::make_stride(n, Int<1>{}, int64_t{0});
  StrideD stride_D = cute::make_stride(n, Int<1>{}, int64_t{0});
  ProblemShape problem_size{m_int, n_int, k_int};
  runner.run_gemm(x, weight, out, out, &problem_size, &stride_A, &stride_B, &stride_C, &stride_D,
                  workspace, workspace_size, alpha, beta, stream);
}
Oops, something went wrong.