Commit f420545: Cleanup and fixes

anthonix committed Jul 9, 2024
1 parent 0ed210b
Showing 6 changed files with 316 additions and 224 deletions.
47 changes: 33 additions & 14 deletions Makefile
@@ -62,29 +62,47 @@ ROCM_PATH ?= /opt/rocm
AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-offload-arch)
HIPCC := $(shell which hipcc 2>/dev/null)
HIPIFY := $(shell which hipify-perl 2>/dev/null)
HIPCC_FLAGS = -O3 -march=native -I$(BUILD_DIR)/hip
HIPCC_FLAGS = -O3 -march=native -I$(BUILD_DIR)/hip -ffast-math -funsafe-math-optimizations -fno-strict-aliasing
HIPCC_FLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
HIPCC_LDFLAGS = -lhipblas -lhipblaslt -lamdhip64
ifneq ($(filter gfx1100,$(AMDGPU_TARGETS)),)
HIPCC_LDFLAGS += -ldevice_gemm_operations -lutility -ldevice_other_operations
else
HIPCC_FLAGS += -DDISABLE_CK
ifneq ($(NO_MULTI_GPU), 1)
ifdef RCCL_PATH
HIPCC_FLAGS += -I$(RCCL_PATH)/include
HIPCC_LDFLAGS += -L$(RCCL_PATH)
endif
ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo "exists"), exists)
HIPCC_FLAGS += -I/usr/lib/x86_64-linux-gnu/openmpi/include -DMULTI_GPU -DUSE_MPI
HIPCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ -lmpi -lrccl
endif
endif
ifdef BUILD_XDL
HIPCC_FLAGS += -DBUILD_XDL
endif
ifdef USE_HIPBLAS
ifdef ROCBLAS_PATH
HIPCC_FLAGS += -I$(ROCBLAS_PATH)/include
HIPCC_LDFLAGS += -L$(ROCBLAS_PATH)/library
endif
HIPCC_FLAGS += -DUSE_HIPBLAS
HIPCC_LDFLAGS += -lhipblas
endif
ifdef DISABLE_CK
HIPCC_FLAGS += -DDISABLE_CK
ifdef HIPBLASLT_PATH
HIPCC_FLAGS += -I$(HIPBLASLT_PATH)/include
HIPCC_LDFLAGS += -L$(HIPBLASLT_PATH)/library
endif
ifdef USE_CK
ifdef CK_PATH
HIPCC_FLAGS += -I$(CK_PATH)/include -DNEW_CK
HIPCC_LDFLAGS += -I$(CK_PATH)/build/lib
endif
HIPCC_FLAGS += -DUSE_CK
HIPCC_LDFLAGS += -ldevice_gemm_operations -lutility -ldevice_other_operations
endif
ifdef WAVEFRONTSIZE64
HIPCC_FLAGS += -DWAVEFRONTSIZE64 -mwavefrontsize64
endif
ifdef CUMODE
HIPCC_FLAGS += -mcumode
endif
ifneq ($(NO_MULTI_GPU), 1)
ifeq ($(shell [ -d /usr/lib/x86_64-linux-gnu/openmpi/lib/ ] && [ -d /usr/lib/x86_64-linux-gnu/openmpi/include/ ] && echo "exists"), exists)
HIPCC_FLAGS += -I/usr/lib/x86_64-linux-gnu/openmpi/include -DMULTI_GPU
HIPCC_LDFLAGS += -L/usr/lib/x86_64-linux-gnu/openmpi/lib/ -lmpi -lrccl
endif
endif
AMD_HEADERS = $(addprefix $(BUILD_DIR)/hip/,$(wildcard llmc/*h))

# autodetect various kinds of support on the current platform
@@ -296,6 +314,7 @@ else
HIPCC_FLAGS += -DXDNN -I$(XDNN_PATH)
HIPCC_LDFLAGS += -L$(XDNN_PATH) -lxdnn
endif
HIPCC_LDFLAGS += -lhipblaslt -lamdhip64

$(info ---------------------------------------------)

26 changes: 10 additions & 16 deletions README.md
@@ -1,25 +1,15 @@
# llm.c for AMD devices
This is a fork of [Andrej Karpathy's llm.c](https://github.com/karpathy/llm.c) with support for AMD devices.
This is a fork of [Andrej Karpathy's llm.c](https://github.com/karpathy/llm.c) with support for AMD's RDNA and CDNA devices.

## Performance

With default settings on a single 7900 XTX, a training step is currently at ~79ms, compared to ~97ms for PyTorch nightly (2.4.0.dev20240513), and ~440ms for tinygrad.

For multiple GPU training, on a machine with four 7900 XTX, throughput is at ~210,000 tokens per second.

Update (5/28/24): The fast attention branch is down to 58.340831 ms per training step on a single 7900 XTX, or 318,777 tok/s on 4x 7900 XTX; currently working on double buffering to push it even further.

## Status

- [x] train_gpt2_fp32 (baseline, minimal changes)
- [x] train_gpt2 with BF16 (baseline, minimal changes)
- [x] train_gpt2 with BF16 and multiple GPUs
- [ ] RDNA3 optimized kernels (in progress)
- [ ] CDNA3 optimized kernels
For the 124M model:
- On a 4x 7900XTX machine, llm.c is ~2.7x faster than PyTorch 2.3.1+rocm6.0 (and ~3.8x faster with optimizations);
- On an 8x MI250X machine, llm.c is ~1.15x faster than PyTorch 2.3.1+rocm6.0 (and ~1.4x faster with optimizations).

## Quick Start (AMD targets)

Install ROCm 6.1.1, checkout the repo, and perform the following steps:
Install the latest ROCm, check out the repo, and perform the following steps:

```
pip install -r requirements.txt
@@ -29,12 +19,16 @@ make train_gpt2amd
./train_gpt2amd
```

The Makefile will build for all AMD targets detected in your machine, but if you wish to only build for a particular target (e.g., if you have an iGPU that you want to ignore), pass the target arch with AMDGPU_TARGETS like so:

```
make train_gpt2amd AMDGPU_TARGETS=gfx1100
```

## Performance tuning

Check the Makefile for advanced build options related to performance, e.g., using local builds of Composable Kernel, hipBLAS, hipBLASLt, etc.
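
As an illustrative sketch (the path below is a placeholder; only the variable names come from this commit's Makefile), a build against a local Composable Kernel checkout might look like:

```
# Illustrative only: substitute your own checkout path.
make train_gpt2amd AMDGPU_TARGETS=gfx1100 USE_CK=1 CK_PATH=$HOME/composable_kernel
```

Other options in the Makefile, such as HIPBLASLT_PATH, ROCBLAS_PATH, USE_HIPBLAS, WAVEFRONTSIZE64, CUMODE, and NO_MULTI_GPU, can be passed the same way.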

---
[ORIGINAL README]
---
193 changes: 0 additions & 193 deletions llmc/amd_common.cuh
@@ -1,11 +1,3 @@
/*
Goal: unobtrusively provide support for AMD devices with minimal changes to the main CUDA code
Example (assuming ROCm 6.1.1 installed in /opt/rocm, or ROCM_PATH environment variable is set):
*/

#pragma once

#ifdef MULTI_GPU
@@ -21,157 +13,6 @@ Example (assuming ROCm 6.1.1 installed in /opt/rocm, or ROCM_PATH environment variable is set):
#define AMD_TARGET_ARCH_CDNA3
#endif

#include <hip/hip_bfloat16.h>

#ifndef DISABLE_CK

#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/ck.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// cublaslt does not have kernels for gfx11, so best alternative in terms of perf/effort seems to be composite_kernels
// somewhat janky to invoke with all of the templating, but works..
static inline void matmul_forward_gfx11(hip_bfloat16* out,
const hip_bfloat16* inp, const hip_bfloat16* weight, const hip_bfloat16* bias,
int B, int T, int C, int OC, cudaStream_t stream) {
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;

auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
auto cde_element_op = CDEElementOp{};

if (bias == NULL) {
auto device_op = ck::tensor_operation::device::DeviceGemmWmma_CShuffle <
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::tensor_layout::gemm::RowMajor,
ck::bhalf_t,
ck::bhalf_t,
ck::bhalf_t,
float,
ck::bhalf_t,
AElementOp,
BElementOp,
CElementOp,
GemmSpec,
256,
128,
256,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8,
1>{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(inp)),
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(weight)),
reinterpret_cast<ck::bhalf_t*>(out),
B*T,
OC,
C,
C,
C,
OC,
a_element_op,
b_element_op,
c_element_op);
invoker.Run(argument, StreamConfig{stream});
} else {
auto device_op = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle <
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::ColumnMajor,
ck::Tuple<ck::tensor_layout::gemm::RowMajor>,
ck::tensor_layout::gemm::RowMajor,
ck::bhalf_t,
ck::bhalf_t,
ck::Tuple<ck::bhalf_t>,
ck::bhalf_t,
float,
ck::bhalf_t,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
256,
128,
256,
8,
8,
16,
16,
4,
4,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
1,
1,
S<1, 32, 1, 8>,
8>{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(inp)),
reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(weight)),
std::array<const void*, 1>{reinterpret_cast<ck::bhalf_t*>(const_cast<hip_bfloat16 *>(bias))},
reinterpret_cast<ck::bhalf_t*>(out),
B*T,
OC,
C,
C,
C,
std::array<ck::index_t, 1>{0},
OC,
a_element_op,
b_element_op,
cde_element_op);
invoker.Run(argument, StreamConfig{stream});
}
}

#endif

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
@@ -331,37 +172,3 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}
#endif

namespace cooperative_groups {
template <typename T>
struct reduce_operator {
static __device__ __forceinline__ T reduce(const T a, const T b) { return a+b; };
};

template <typename T>
struct plus : public reduce_operator<T> {
static __device__ __forceinline__ T reduce(const T a, const T b) {
return a + b;
}
};

template <typename T>
struct greater : public reduce_operator<T> {
static __device__ __forceinline__ T reduce(const T a, const T b) {
return fmaxf(a, b);
}
};

template <typename T>
static __device__ __forceinline__ float reduce(const thread_block_tile<32>& warp, float x, const plus<T>& op) {
return warp_reduce_sum(x);
}

template <typename T>
static __device__ __forceinline__ float reduce(const thread_block_tile<32>& warp, float x, const greater<T>& op) {
return warp_reduce_max(x);
}

template struct plus<float>;
template struct greater<float>;
}
4 changes: 4 additions & 0 deletions llmc/cublas_common.h
@@ -31,6 +31,10 @@ const size_t cublaslt_workspace_size = 32 * 1024 * 1024;
void* cublaslt_workspace = NULL;
cublasComputeType_t cublas_compute = CUBLAS_COMPUTE_32F;
cublasLtHandle_t cublaslt_handle;
#if defined(BUILD_AMD) && defined(USE_HIPBLAS)
cublasHandle_t cublas_handle;
void* cublas_workspace = NULL;
#endif

// ----------------------------------------------------------------------------
// Error checking