diff --git a/Makefile b/Makefile
index 14aac70a5..6c6699c95 100644
--- a/Makefile
+++ b/Makefile
@@ -62,16 +62,40 @@ ROCM_PATH ?= /opt/rocm
 AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-offload-arch)
 HIPCC := $(shell which hipcc 2>/dev/null)
 HIPIFY := $(shell which hipify-perl 2>/dev/null)
-HIPCC_FLAGS = -O3 -march=native -I$(BUILD_DIR)/hip
+HIPCC_FLAGS = -O3 -march=native -I$(BUILD_DIR)/hip -ffast-math -funsafe-math-optimizations -fno-strict-aliasing
 HIPCC_FLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
-HIPCC_LDFLAGS = -lhipblas -lhipblaslt -lamdhip64
-ifneq ($(filter gfx1100,$(AMDGPU_TARGETS)),)
-  HIPCC_LDFLAGS += -ldevice_gemm_operations -lutility -ldevice_other_operations
-else
-  HIPCC_FLAGS += -DDISABLE_CK
+ifneq ($(NO_MULTI_GPU), 1)
+  ifdef RCCL_PATH
+    HIPCC_FLAGS += -I$(RCCL_PATH)/include
+    HIPCC_LDFLAGS += -L$(RCCL_PATH)
+  endif
+  ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo "exists"), exists)
+    HIPCC_FLAGS += -I/usr/lib/x86_64-linux-gnu/openmpi/include -DMULTI_GPU -DUSE_MPI
+    HIPCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ -lmpi -lrccl
+  endif
+endif
+ifdef BUILD_XDL
+  HIPCC_FLAGS += -DBUILD_XDL
+endif
+ifdef USE_HIPBLAS
+  ifdef ROCBLAS_PATH
+    HIPCC_FLAGS += -I$(ROCBLAS_PATH)/include
+    HIPCC_LDFLAGS += -L$(ROCBLAS_PATH)/library
+  endif
+  HIPCC_FLAGS += -DUSE_HIPBLAS
+  HIPCC_LDFLAGS += -lhipblas
 endif
-ifdef DISABLE_CK
-  HIPCC_FLAGS += -DDISABLE_CK
+ifdef HIPBLASLT_PATH
+  HIPCC_FLAGS += -I$(HIPBLASLT_PATH)/include
+  HIPCC_LDFLAGS += -L$(HIPBLASLT_PATH)/library
+endif
+ifdef USE_CK
+  ifdef CK_PATH
+    HIPCC_FLAGS += -I$(CK_PATH)/include -DNEW_CK
+    HIPCC_LDFLAGS += -L$(CK_PATH)/build/lib
+  endif
+  HIPCC_FLAGS += -DUSE_CK
+  HIPCC_LDFLAGS += -ldevice_gemm_operations -lutility -ldevice_other_operations
 endif
 ifdef WAVEFRONTSIZE64
   HIPCC_FLAGS += -DWAVEFRONTSIZE64 -mwavefrontsize64
@@ -79,12 +103,6 @@ endif
 ifdef CUMODE
   HIPCC_FLAGS += -mcumode
 endif
-ifneq ($(NO_MULTI_GPU), 1)
-  ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo "exists"), exists)
-    HIPCC_FLAGS += -I/usr/lib/x86_64-linux-gnu/openmpi/include -DMULTI_GPU
-    HIPCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ -lmpi -lrccl
-  endif
-endif
 AMD_HEADERS = $(addprefix $(BUILD_DIR)/hip/,$(wildcard llmc/*h))
 
 # autodect a lot of various supports on current platform
@@ -296,6 +314,7 @@ else
   HIPCC_FLAGS += -DXDNN -I$(XDNN_PATH)
   HIPCC_LDFLAGS += -L$(XDNN_PATH) -lxdnn
 endif
+HIPCC_LDFLAGS += -lhipblaslt -lamdhip64
 
 $(info ---------------------------------------------)
 
diff --git a/README.md b/README.md
index 408f40f21..ff339fd2c 100644
--- a/README.md
+++ b/README.md
@@ -1,25 +1,15 @@
 # llm.c for AMD devices
 
-This is a fork of [Andrej Karpathy's llm.c](https://github.com/karpathy/llm.c) with support for AMD devices.
+This is a fork of [Andrej Karpathy's llm.c](https://github.com/karpathy/llm.c) with support for AMD's RDNA and CDNA devices.
 
 ## Performance
 
-With default settings on a single 7900 XTX, a training step is currently at ~79ms, compared to ~97ms for PyTorch nightly (2.4.0.dev20240513), and ~440ms for tinygrad.
-
-For multiple GPU training, on a machine with four 7900 XTX, throughput is at ~210,000 tokens per second.
-
-Update (5/28/24): Fast attention branch down to 58.340831 ms / training step on single 7900 XTX, or 318777 tok/s on 4x 7900 XTX.. currently working on double buffering to push it even further.
-
-## Status
-
-- [x] train_gpt2_fp32 (baseline, minimal changes)
-- [x] train_gpt2 with BF16 (baseline, minimal changes)
-- [x] train_gpt2 with BF16 and multiple GPUs
-- [ ] RDNA3 optimized kernels (in progress)
-- [ ] CDNA3 optimized kernels
+For the 124M model:
+- On a 4x 7900XTX machine, llm.c is ~2.7x faster than PyTorch 2.3.1+rocm6.0 (and ~3.8x faster with optimizations);
+- On an 8x MI250X machine, llm.c is ~1.15x faster than PyTorch 2.3.1+rocm6.0 (and ~1.4x faster with optimizations).
 
 ## Quick Start (AMD targets)
 
-Install ROCm 6.1.1, checkout the repo, and perform the following steps:
+Install the latest ROCm, check out the repo, and perform the following steps:
 ```
 pip install -r requirements.txt
@@ -29,12 +19,16 @@ make train_gpt2amd
 ./train_gpt2amd
 ```
 
-The Makefile will build for all AMD targets detected in your machine, but if you wish to only only build for a particular target (e.g., if you have a iGPU that you want to ignore), pass the target arch with AMDGPU_TARGETS like so:
+The Makefile will build for all AMD targets detected on your machine, but if you wish to build only for a particular target (e.g., if you have an iGPU that you want to ignore), pass the target arch with AMDGPU_TARGETS like so:
 
 ```
 make train_gpt2amd AMDGPU_TARGETS=gfx1100
 ```
 
+## Performance tuning
+
+Check the Makefile for advanced build options related to performance, e.g., using local builds of Composable Kernel, hipBLAS, hipBLASLt, etc.
+
 ---
 [ORIGINAL README]
 ---
diff --git a/llmc/amd_common.cuh b/llmc/amd_common.cuh
index 4cd23ad8d..e650d2111 100644
--- a/llmc/amd_common.cuh
+++ b/llmc/amd_common.cuh
@@ -1,11 +1,3 @@
-/*
-
-Goal: unobtrusively provide support for AMD devices with minimal changes to the main CUDA code
-
-Example (assuming ROCm 6.1.1 installed in /opt/rocm, or ROCM_PATH environment variable is set):
-
-*/
-
 #pragma once
 
 #ifdef MULTI_GPU
@@ -21,157 +13,6 @@ Example (assuming ROCm 6.1.1 installed in /opt/rocm, or ROCM_PATH environment va
 #define AMD_TARGET_ARCH_CDNA3
 #endif
 
-#include
-
-#ifndef DISABLE_CK
-
-#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
-#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
-#include "ck/ck.hpp"
-
-template
-using S = ck::Sequence;
-
-static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
-
-// cublaslt does not have kernels for gfx11, so best alternative in terms of perf/effort seems to be composite_kernels
-// somewhat janky to invoke with all of the templating, but works..
-static inline void matmul_forward_gfx11(hip_bfloat16* out, - const hip_bfloat16* inp, const hip_bfloat16* weight, const hip_bfloat16* bias, - int B, int T, int C, int OC, cudaStream_t stream) { - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - using CDEElementOp = ck::tensor_operation::element_wise::Add; - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - auto cde_element_op = CDEElementOp{}; - - if (bias == NULL) { - auto device_op = ck::tensor_operation::device::DeviceGemmWmma_CShuffle < - ck::tensor_layout::gemm::RowMajor, - ck::tensor_layout::gemm::ColumnMajor, - ck::tensor_layout::gemm::RowMajor, - ck::bhalf_t, - ck::bhalf_t, - ck::bhalf_t, - float, - ck::bhalf_t, - AElementOp, - BElementOp, - CElementOp, - GemmSpec, - 256, - 128, - 256, - 8, - 8, - 16, - 16, - 4, - 4, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, - 1, - S<1, 32, 1, 8>, - 8, - 1>{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument( - reinterpret_cast(const_cast(inp)), - reinterpret_cast(const_cast(weight)), - reinterpret_cast(out), - B*T, - OC, - C, - C, - C, - OC, - a_element_op, - b_element_op, - c_element_op); - invoker.Run(argument, StreamConfig{stream}); - } else { - auto device_op = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle < - ck::tensor_layout::gemm::RowMajor, - ck::tensor_layout::gemm::ColumnMajor, - ck::Tuple, - ck::tensor_layout::gemm::RowMajor, - ck::bhalf_t, - ck::bhalf_t, - ck::Tuple, - ck::bhalf_t, - float, - ck::bhalf_t, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 256, - 128, - 256, - 8, - 8, - 16, - 16, - 4, - 4, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, - 1, - S<1, 32, 1, 8>, - 8>{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument( - reinterpret_cast(const_cast(inp)), - reinterpret_cast(const_cast(weight)), - std::array{reinterpret_cast(const_cast(bias))}, - reinterpret_cast(out), - B*T, - OC, - C, - C, - C, - std::array{0}, - OC, - a_element_op, - b_element_op, - cde_element_op); - invoker.Run(argument, StreamConfig{stream}); - } -} - -#endif - #include #include #include @@ -331,37 +172,3 @@ static __device__ __forceinline__ float warp_reduce_max(float x) { return x; } #endif - -namespace cooperative_groups { -template -struct reduce_operator { - static __device__ __forceinline__ T reduce(const T a, const T b) { return a+b; }; -}; - -template -struct plus : public reduce_operator { - static __device__ __forceinline__ T reduce(const T a, const T b) { - return a + b; - } -}; - -template -struct greater : public reduce_operator { - static __device__ __forceinline__ T reduce(const T a, const T b) { - return fmaxf(a, b); - } -}; - -template -static __device__ __forceinline__ float reduce(const thread_block_tile<32>& warp, float x, const plus& op) { - return warp_reduce_sum(x); -} - -template -static __device__ __forceinline__ float reduce(const thread_block_tile<32>& warp, float x, const greater& op) { - return warp_reduce_max(x); -} - -template struct plus; -template struct greater; -} diff --git a/llmc/cublas_common.h b/llmc/cublas_common.h index e658eca2d..2edeabbed 100644 --- 
a/llmc/cublas_common.h +++ b/llmc/cublas_common.h @@ -31,6 +31,10 @@ const size_t cublaslt_workspace_size = 32 * 1024 * 1024; void* cublaslt_workspace = NULL; cublasComputeType_t cublas_compute = CUBLAS_COMPUTE_32F; cublasLtHandle_t cublaslt_handle; +#if defined(BUILD_AMD) && defined(USE_HIPBLAS) +cublasHandle_t cublas_handle; +void* cublas_workspace = NULL; +#endif // ---------------------------------------------------------------------------- // Error checking diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh index e6984309c..3311c8b9a 100644 --- a/llmc/matmul.cuh +++ b/llmc/matmul.cuh @@ -1,6 +1,18 @@ /* Matrix Multiplication, with help from cuBLASLt */ +#ifdef USE_CK +#include +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/ck.hpp" +#endif + #include #include // std::bool_constant // llmc internal imports @@ -104,6 +116,224 @@ __global__ void reduce_add_sum_kernel(floatX* dst, const float* src, size_t n, s // ---------------------------------------------------------------------------- // kernel launchers +#ifdef USE_CK + +template +using S = ck::Sequence; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +void matmul_ck(__hip_bfloat16*__restrict__ d, const __hip_bfloat16*__restrict__ a, const __hip_bfloat16*__restrict__ b, const __hip_bfloat16*__restrict__ bias, + int m, int n, int k, cudaStream_t stream=0, bool transA=true, bool transB=false, + int batch_count=0, size_t strideA=0, size_t strideB=0, size_t strideOut=0, + bool accumulate=false, __hip_bfloat16*__restrict__ pre_gelu=NULL, bool backward=false) +{ + NVTX_RANGE_FN(); + if (pre_gelu) { printf("%s: GELU in matmul unsupported\n", __PRETTY_FUNCTION__); exit(-1); } + if (transA != true || transB != false) { printf("%s: unsupported transA/B\n", __PRETTY_FUNCTION__); exit(-1); } + if (batch_count != 0 || strideA != 0 || strideB != 0 || strideOut != 0) { printf("%s: batch_count != 0 not supported\n", __PRETTY_FUNCTION__); exit(-1); } + if (accumulate) { printf("%s: accumulate without batch not supported\n", __PRETTY_FUNCTION__); exit(-1); } + + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + using ELayout = ck::tensor_layout::gemm::RowMajor; + + using DataType = ck::bhalf_t; + using AccDataType = float; + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using CDEElementOp = ck::tensor_operation::element_wise::Add; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto StrideA = k; auto StrideB = k; auto StrideC = m; + + if (bias) { +#ifdef BUILD_XDL + auto device_op = 
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle< + ALayout, BLayout, ck::Tuple, ELayout, + DataType, DataType, AccDataType, DataType, ck::Tuple, DataType, + AElementOp, BElementOp, CDEElementOp, + GemmDefault, + 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>{}; +#else +#ifdef NEW_CK + auto device_op = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, BLayout, ck::Tuple, ELayout, + DataType, DataType, AccDataType, DataType, ck::Tuple, DataType, + AElementOp, BElementOp, CDEElementOp, + GemmDefault, + 2, 128, 64, 128, 64, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, + S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, ck::make_default_loop_scheduler(), ck::PipelineVersion::v2>{}; +#else + auto device_op = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle < + ALayout, BLayout, ck::Tuple, ELayout, + DataType, DataType, ck::Tuple, DataType, AccDataType, DataType, + AElementOp, BElementOp, CDEElementOp, + GemmSpec, + 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>{}; +#endif +#endif + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + reinterpret_cast(const_cast<__hip_bfloat16 *>(b)), + reinterpret_cast(const_cast<__hip_bfloat16 *>(a)), + std::array{reinterpret_cast(const_cast<__hip_bfloat16 *>(bias))}, + reinterpret_cast(d), + n, m, k, StrideA, StrideB, std::array{0}, StrideC, + a_element_op, b_element_op, cde_element_op); + invoker.Run(argument, StreamConfig{stream}); + + } else { +#ifdef BUILD_XDL + auto device_op = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, ELayout, + DataType, DataType, DataType, AccDataType, DataType, + AElementOp, BElementOp, CElementOp, + GemmDefault, + 256, 128, 128, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>{}; +#else +#ifdef NEW_CK + auto device_op = ck::tensor_operation::device::DeviceGemmWmma_CShuffle < + ALayout, BLayout, ELayout, + DataType, DataType, DataType, AccDataType, DataType, + AElementOp, BElementOp, CElementOp, + GemmDefault, + 2, 128, 64, 128, 64, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, + S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, ck::make_default_loop_scheduler(), ck::PipelineVersion::v2>{}; +#else + auto device_op = ck::tensor_operation::device::DeviceGemmWmma_CShuffle < + ALayout, BLayout, ELayout, + DataType, DataType, DataType, AccDataType, DataType, + AElementOp, BElementOp, CElementOp, + GemmSpec, + 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>{}; +#endif +#endif + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + reinterpret_cast(const_cast<__hip_bfloat16 *>(b)), + reinterpret_cast(const_cast<__hip_bfloat16 *>(a)), + reinterpret_cast(d), + n, m, k, StrideA, StrideB, StrideC, +#ifdef BUILD_XDL + 1, +#endif + a_element_op, b_element_op, c_element_op); + invoker.Run(argument, StreamConfig{stream}); + + } +} + +void matmul_ck(floatX*__restrict__ d, const 
floatX*__restrict__ a, const floatX*__restrict__ b, const floatX*__restrict__ bias, + int m, int n, int k, cudaStream_t stream=0, bool transA=true, bool transB=false, + int batch_count=0, size_t strideA=0, size_t strideB=0, size_t strideOut=0, + bool accumulate=false, floatX*__restrict__ pre_gelu=NULL, bool backward=false) +{ + matmul_ck(reinterpret_cast<__hip_bfloat16*__restrict__>(d), + reinterpret_cast(a), + reinterpret_cast(b), + reinterpret_cast(bias), + m, n, k, stream, transA, transB, batch_count, strideA, strideB, strideOut, accumulate, + reinterpret_cast<__hip_bfloat16*__restrict__>(pre_gelu), backward + ); +} + +#endif +#ifdef USE_HIPBLAS + +__device__ __forceinline__ __hip_bfloat162 __float_as_bfloat162(float x) { + unsigned int temp = __float_as_uint(x); + return *reinterpret_cast<__hip_bfloat162 *>(&temp); +} +__device__ __forceinline__ float __bfloat162_as_float(__hip_bfloat162 x) { + return *reinterpret_cast(&x); +} + +__global__ void add_bias(floatX*__restrict__ out, const floatX*__restrict__ bias, const int rows, const int cols) { + const int tid = threadIdx.x; + const int stride = blockDim.x * 8; + + floatX *__restrict__ p0 = out + (blockIdx.x * cols) + (tid * 8); + const floatX *__restrict__ p1 = bias + (tid * 8); + + for (int i = tid*8; i < cols; i += stride) { + float d0[4], d1[4]; + for(int x=0;x<4;x++) d0[x] = reinterpret_cast(p0)[x]; + for(int x=0;x<4;x++) d1[x] = reinterpret_cast(p1)[x]; + for(int x=0;x<4;x++) { + __hip_bfloat162 t0 = __float_as_bfloat162(d0[x]); + __hip_bfloat162 t1 = __float_as_bfloat162(d1[x]); + t0.x = t0.x + t1.x; + t0.y = t0.y + t1.y; + d0[x] = __bfloat162_as_float(t0); + } + + float *__restrict__ pout = reinterpret_cast(p0); + for(int x=0;x<4;x++) pout[x] = d0[x]; + + p0 += stride; p1 += stride; + } +} + +void matmul_cublas(floatX*__restrict__ d, const floatX*__restrict__ a, const floatX*__restrict__ b, const floatX*__restrict__ bias, + int m, int n, int k, cudaStream_t stream=0, bool transA=true, bool transB=false, + int batch_count=0, size_t strideA=0, size_t strideB=0, size_t strideOut=0, + bool accumulate=false, floatX*__restrict__ pre_gelu=NULL, bool backward=false) +{ + NVTX_RANGE_FN(); + if (pre_gelu != NULL) { printf("%s: GELU unsupported\n", __PRETTY_FUNCTION__); exit(-1); } + + cublasCheck(cublasSetStream(cublas_handle, stream)); + + cublasOperation_t transa = transA? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t transb = transB? CUBLAS_OP_T : CUBLAS_OP_N; + + float one = 1.0f, zero = 0.0f; + + const int lda = transA? k : m; + const int ldb = transB? 
n : k; + const int ldc = m; + + if (batch_count != 0 || strideA != 0 || strideB != 0 || strideOut != 0) { + if (accumulate) { + cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, transa, transb, m, n, k, &one, + a, CUBLAS_LOWP, lda, strideA, b, CUBLAS_LOWP, ldb, strideB, &one, + d, CUBLAS_LOWP, ldc, strideOut, batch_count, cublas_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + cublasCheck(cublasGemmStridedBatchedEx(cublas_handle, transa, transb, m, n, k, &one, + a, CUBLAS_LOWP, lda, strideA, b, CUBLAS_LOWP, ldb, strideB, &zero, + d, CUBLAS_LOWP, ldc, strideOut, batch_count, cublas_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + } else { + if (accumulate) { + cublasCheck(cublasGemmEx(cublas_handle, transa, transb, m, n, k, &one, + a, CUBLAS_LOWP, lda, b, CUBLAS_LOWP, ldb, &one, + d, CUBLAS_LOWP, ldc, cublas_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + cublasCheck(cublasGemmEx(cublas_handle, transa, transb, m, n, k, &one, + a, CUBLAS_LOWP, lda, b, CUBLAS_LOWP, ldb, &zero, + d, CUBLAS_LOWP, ldc, cublas_compute, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + + } + if (bias != NULL) { + add_bias<<>>(d, bias, n, m); + cudaCheck(cudaGetLastError()); + } + +} + +#endif + // Wrapper around cublasLtMatmul that is meant to support everything we need in llm.c // https://docs.nvidia.com/cuda/cublas/#cublasltmatmul void matmul_cublaslt(floatX* d, const floatX* a, const floatX* b, const floatX* bias, @@ -234,10 +464,22 @@ void matmul_forward_cublaslt(floatX* out, floatX* pre_gelu=NULL, int gelu_fusion=1) { // By default only fuse GELU for H100+ as cuBLAS seems to be inefficient for fused GELU on Ada/Ampere (?) if (gelu_fusion < 1 && pre_gelu) { +#if defined(USE_CK) + matmul_ck(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); +#elif defined(USE_HIPBLAS) + matmul_cublas(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); +#else matmul_cublaslt(pre_gelu, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, NULL, false); +#endif gelu_forward(out, pre_gelu, B*T*OC, stream); } else { +#if defined(USE_CK) + matmul_ck(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); +#elif defined(USE_HIPBLAS) + matmul_cublas(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); +#else matmul_cublaslt(out, weight, inp, bias, OC, B*T, C, stream, true, false, 0, 0, 0, 0, false, pre_gelu, false); +#endif } } @@ -276,8 +518,13 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, } // backward to input, uses = in the backward pass (set the gradient) +#if defined(USE_HIPBLAS) + matmul_cublas(dinp, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0, false, + gelu_fusion >= 2 ? pre_gelu : NULL, true); +#else matmul_cublaslt(dinp, weight, dout, NULL, C, B*T, OC, stream, false, false, 0, 0, 0, 0, false, gelu_fusion >= 2 ? 
pre_gelu : NULL, true); +#endif // backward GELU (if it wasn't fused into the matmul above) if (gelu_fusion < 2 && pre_gelu) { @@ -285,6 +532,11 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias, } // backward to weight, uses += in the backward pass (accumulate the gradient) by setting alpha=one +#if defined(USE_HIPBLAS) + matmul_cublas(dweight, inp, dout, NULL /*dbias*/, C, OC, B*T, stream, false, true, 0, 0, 0, 0, + true /* accumulate */, NULL, true); +#else matmul_cublaslt(dweight, inp, dout, NULL /*dbias*/, C, OC, B*T, stream, false, true, 0, 0, 0, 0, true /* accumulate */, NULL, true); +#endif } diff --git a/train_gpt2.cu b/train_gpt2.cu index f7b6644ec..8f5432720 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -9,6 +9,13 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include #include #include +#ifdef USE_CK +#include +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/ck.hpp" +#endif // ----------- CPU utilities ----------- // defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck // defines: create_dir_if_not_exists, find_max_step @@ -209,7 +216,7 @@ typedef struct { float* ln1_rstd; // (L, B, T) floatX* atty; // (L, B, T, C) // cuDNN saves only some statistics information -#if ENABLE_CUDNN +#if defined(ENABLE_CUDNN) || defined(XDNN) float* att; // (L, B, NH, T) #else floatX* att; // (L, B, NH, T, T) @@ -1176,6 +1183,11 @@ void common_start(bool override_enable_tf32 = true, bool print_device_info = tru // set up cuBLAS and cuBLASLt cublasCheck(cublasLtCreate(&cublaslt_handle)); cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); +#if defined(BUILD_AMD) && defined(USE_HIPBLAS) + cublasCheck(cublasCreate(&cublas_handle)); + cudaCheck(cudaMalloc(&cublas_workspace, cublaslt_workspace_size)); + cublasCheck(cublasSetWorkspace(cublas_handle, cublas_workspace, cublaslt_workspace_size)); +#endif // TF32 precision is equivalent to torch.set_float32_matmul_precision('high') bool enable_tf32 = PRECISION_MODE == PRECISION_FP32 && deviceProp.major >= 8 && override_enable_tf32; @@ -1190,6 +1202,10 @@ void common_free(GPT2 &model) { cudaCheck(cudaStreamDestroy(main_stream)); cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasLtDestroy(cublaslt_handle)); +#if defined(BUILD_AMD) && defined(USE_HIPBLAS) + cublasCheck(cublasDestroy(cublas_handle)); + cudaCheck(cudaFree(cublas_workspace)); +#endif #ifdef ENABLE_CUDNN destroy_cudnn(); #endif
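
The README's new "Performance tuning" section points at the Makefile's advanced build options without showing them in use. As a rough usage sketch (paths are placeholders, not part of the patch), the options introduced in the Makefile hunks above combine along these lines:

```
# default build: auto-detect AMDGPU_TARGETS, link hipBLASLt/amdhip64,
# and enable MPI+RCCL multi-GPU if a system OpenMPI install is found
make train_gpt2amd

# build for a single target and take GEMMs from a local Composable Kernel build
# (CK_PATH below is a placeholder)
make train_gpt2amd AMDGPU_TARGETS=gfx1100 USE_CK=1 CK_PATH=$HOME/composable_kernel

# CDNA-style build selecting the XDL CK kernels instead of the WMMA ones
make train_gpt2amd USE_CK=1 BUILD_XDL=1 CK_PATH=$HOME/composable_kernel

# route matmuls through plain hipBLAS, or drop multi-GPU support entirely
make train_gpt2amd USE_HIPBLAS=1
make train_gpt2amd NO_MULTI_GPU=1
```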