Fix some issues with CUDA
Still broken on Windows AMD :'(
jart committed Apr 21, 2024
1 parent 850f43d commit a49e33d
Showing 3 changed files with 91 additions and 120 deletions.
2 changes: 2 additions & 0 deletions llama.cpp/ggml-backend-impl.h
@@ -150,6 +150,8 @@ extern "C" {
void (*GGML_CALL exit)(int);
void (*GGML_CALL free)(void *);
void *(*GGML_CALL malloc)(size_t);
char *(*GGML_CALL getenv)(const char *);
long (*GGML_CALL write)(int, const void *, long);
void (*GGML_CALL ggml_backend_register)(const char *, ggml_backend_init_fn, ggml_backend_buffer_type_t, void *);
ggml_backend_buffer_t (*GGML_CALL ggml_backend_buffer_init)(ggml_backend_buffer_type_t, struct ggml_backend_buffer_i, ggml_backend_buffer_context_t, size_t);
ggml_backend_buffer_t (*GGML_CALL ggml_backend_cpu_buffer_from_ptr)(void *, size_t);
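
The two new entries extend the table of host function pointers that a dynamically loaded GPU module uses in place of its own C runtime; later in this commit the CUDA source maps getenv onto g_backend->getenv and routes its logging through g_backend->write. A minimal sketch of that consumption pattern, using hypothetical names (host_api, module_verbose, MODULE_VERBOSE) that are not part of the repository:

struct host_api {                                // hypothetical stand-in for ggml_backend_api
    char *(*getenv)(const char *);               // resolved inside the host process
    long  (*write)(int, const void *, long);     // host-side write(2); fd 2 = stderr
};

static const host_api *g_host;                   // handed over once by the host

static bool module_verbose(void) {
    const char *v = g_host->getenv("MODULE_VERBOSE");  // hypothetical variable name
    return v != nullptr && *v != '\0';
}

static void module_log(const char *msg, long len) {
    g_host->write(2, msg, len);                  // one write per message, via the host
}
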
10 changes: 10 additions & 0 deletions llama.cpp/ggml-backend.c
@@ -2114,11 +2114,21 @@ GGML_CALL static void *system_malloc(size_t n) {
return malloc(n);
}

GGML_CALL static char *system_getenv(const char *s) {
return getenv(s);
}

GGML_CALL static long system_write(int fd, const void *p, long n) {
return write(fd, p, n);
}

static const struct ggml_backend_api kGgmlBackendApi = {
&FLAG_log_disable,
system_exit,
system_free,
system_malloc,
system_getenv,
system_write,
ggml_backend_register,
ggml_backend_buffer_init,
ggml_backend_cpu_buffer_from_ptr,
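
For orientation, a simplified sketch of the other half of the handshake: the host builds a table of libc wrappers like kGgmlBackendApi above and passes it to the module's link entry point. This assumes a plain POSIX dlopen() loader and a cut-down two-entry table (mini_backend_api, module_link, kHostApi are invented names); llamafile's real loader and the full ggml_backend_api struct are larger than this.

// build sketch: g++ host_loader.cpp -ldl
#include <dlfcn.h>
#include <cstdio>
#include <cstdlib>
#include <unistd.h>

struct mini_backend_api {                        // hypothetical, not the real struct
    char *(*getenv)(const char *);
    long  (*write)(int, const void *, long);
};

static char *host_getenv(const char *s) { return getenv(s); }
static long  host_write(int fd, const void *p, long n) { return write(fd, p, n); }

static const mini_backend_api kHostApi = { host_getenv, host_write };

int main(int argc, char **argv) {
    if (argc < 2) return 1;
    void *dso = dlopen(argv[1], RTLD_NOW);       // path to the compiled GPU module
    if (!dso) { fprintf(stderr, "dlopen: %s\n", dlerror()); return 1; }
    auto link = (bool (*)(const mini_backend_api *))dlsym(dso, "module_link");
    if (!link) { fprintf(stderr, "missing module_link\n"); return 1; }
    return link(&kHostApi) ? 0 : 1;              // module stores the table and probes devices
}
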
199 changes: 79 additions & 120 deletions llama.cpp/ggml-cuda.cu
@@ -1,12 +1,29 @@
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi

static void ggml_cuda_print(int, const char *, ...);

#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cfloat>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <float.h>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <stdarg.h>

#include "ggml.h"
#include "ggml-cuda.h"

#include <memory>

#if defined(GGML_USE_HIPBLAS)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
@@ -16,12 +33,6 @@
#endif
#include "ggml-common.h"

#include <cstdio>
#include <array>
#include <cassert>
#include <cfloat>
#include <string>

#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
@@ -202,6 +213,7 @@

static const struct ggml_backend_api *g_backend;
#define exit g_backend->exit
#define getenv g_backend->getenv
#define FLAG_log_disable (*g_backend->FLAG_log_disable)
#define ggml_backend_register g_backend->ggml_backend_register
#define ggml_is_quantized g_backend->ggml_is_quantized
@@ -230,17 +242,49 @@ static const struct ggml_backend_api *g_backend;
#define ggml_is_empty g_backend->ggml_is_empty
#define ggml_op_desc g_backend->ggml_op_desc

GGML_CALL bool ggml_cuda_link(const struct ggml_backend_api *backend_api) {
g_backend = backend_api;
if (!FLAG_log_disable) {
fprintf(stderr, "%s: welcome to " GGML_CUDA_NAME " SDK with "
// printf() and fprintf() runtime bridge
// this is needed so text gets printed on windows
// it also helps ensure the atomicity of log lines
static void ggml_cuda_print(const char *fmt, ...) {
#define GGML_CUDA_PRINT_BUFSIZ 512
#define fflush(_) (void)0
#define printf(...) ggml_cuda_print(__VA_ARGS__)
#define fprintf(_, ...) ggml_cuda_print(__VA_ARGS__)
int len;
va_list va;
char buf[GGML_CUDA_PRINT_BUFSIZ];
va_start(va, fmt);
len = vsnprintf(buf, GGML_CUDA_PRINT_BUFSIZ, fmt, va);
va_end(va);
if (len < 0)
len = strnlen(buf, GGML_CUDA_PRINT_BUFSIZ);
if (len >= GGML_CUDA_PRINT_BUFSIZ) {
len = GGML_CUDA_PRINT_BUFSIZ;
buf[len - 4] = '.';
buf[len - 3] = '.';
buf[len - 2] = '.';
buf[len - 1] = '\n';
}
g_backend->write(2, buf, len);
}

#ifdef GGML_USE_TINYBLAS
"tinyBLAS"
#define BLAS_NAME "tinyBLAS"
#else
GGML_CUBLAS_NAME
#define BLAS_NAME GGML_CUBLAS_NAME
#endif

GGML_CALL bool ggml_cuda_link(const struct ggml_backend_api *backend_api) {
g_backend = backend_api;
if (!FLAG_log_disable)
fprintf(stderr, "%s: welcome to " GGML_CUDA_NAME " SDK with " BLAS_NAME "\n", __func__);
#ifdef __HIP_PLATFORM_AMD__
// cargo culting workaround below
#ifndef GGML_USE_TINYBLAS
rocblas_initialize();
cudaDeviceSynchronize();
#endif
#endif
"\n", __func__);
}
int device_count;
return cudaGetDeviceCount(&device_count) == cudaSuccess && device_count > 0;
}
@@ -348,11 +392,11 @@ static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
(printf)("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
file_name, line, function_name, arch);
GGML_UNUSED(arch_list);
#else
printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
(printf)("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
file_name, line, function_name, arch, arch_list);
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
__trap();
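
The calls above are written as (printf)(...) rather than printf(...) because no_device_code runs on the device: earlier in this file printf becomes a function-like macro that forwards to the host logging bridge, and parenthesizing the name keeps that macro from expanding, so CUDA's builtin device-side printf is used instead. A standalone sketch of the mechanism, compilable with nvcc (ggml_cuda_print_stub and demo_kernel are made-up names):

#include <cstdio>

static void ggml_cuda_print_stub(const char *fmt, ...) { (void)fmt; }  // stand-in for the host bridge
#define printf(...) ggml_cuda_print_stub(__VA_ARGS__)

__global__ void demo_kernel() {
    // a function-like macro only expands when its name is immediately
    // followed by '(', so this calls the builtin device printf
    (printf)("hello from thread %d\n", (int)threadIdx.x);
}

#undef printf

int main() {
    demo_kernel<<<1, 4>>>();
    cudaDeviceSynchronize();   // flush device-side printf output
    return 0;
}
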
@@ -7880,43 +7924,30 @@ void ggml_cuda_op_mul_mat_vec_q(
GGML_UNUSED(src1_padded_row_size);
}

template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int row = blockIdx.x;
const int tid = threadIdx.x;

float2 mean_var = make_float2(0.f, 0.f);

for (int col = tid; col < ncols; col += block_size) {
for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
}

// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}

const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);

for (int col = tid; col < ncols; col += block_size) {
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
}
}

template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
// threadIdx.x: block_size idx
@@ -7931,117 +7962,62 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr

float tmp = 0.0f; // partial sum for thread in warp

for (int j = start; j < end; j += block_size) {
for (int j = start; j < end; j += WARP_SIZE) {
tmp += x[j];
}

tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}

float mean = tmp / group_size;
tmp = 0.0f;

for (int j = start; j < end; j += block_size) {
for (int j = start; j < end; j += WARP_SIZE) {
float xi = x[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}

tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}

float variance = tmp / group_size;
float scale = rsqrtf(variance + eps);
for (int j = start; j < end; j += block_size) {
for (int j = start; j < end; j += WARP_SIZE) {
dst[j] *= scale;
}
}

template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps, int nrows) {
const int row = blockIdx.x;
const int tid = threadIdx.x;

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += block_size) {
for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}

// sum up partial sums
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}

const float mean = tmp / ncols;
const float mean = warp_reduce_sum(tmp) / ncols;
const float scale = rsqrtf(mean + eps);

for (int col = tid; col < ncols; col += block_size) {
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = scale * x[row*ncols + col];
}
}

static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
norm_f32<<<nrows, WARP_SIZE, 0, stream>>>(x, dst, ncols, eps);
}

static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
static const float eps = 1e-6f;
if (group_size < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
} else {
const dim3 block_dims(1024, 1, 1);
group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
group_norm_f32<<<num_groups, WARP_SIZE, 0, stream>>>(x, dst, group_size, ne_elements, eps);
}

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
rms_norm_f32<<<nrows, WARP_SIZE, 0, stream>>>(x, dst, ncols, eps, nrows);
}
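
The rewritten launchers above always start one 32-thread warp per row or group, where the old code picked between a one-warp and a 1024-thread block and needed a shared-memory pass to combine warps. With a single warp, a register-level shuffle reduction is enough. A generic sketch of that pattern (warp_reduce_sum_demo and row_sum_demo are illustrative names, not the repo's helpers):

__device__ __forceinline__ float warp_reduce_sum_demo(float x) {
    // butterfly exchange: after 5 steps every lane holds the sum of all 32 lanes
    for (int mask = 16; mask > 0; mask >>= 1)
        x += __shfl_xor_sync(0xffffffffu, x, mask, 32);
    return x;
}

__global__ void row_sum_demo(const float *x, float *out, int ncols) {
    const int row = blockIdx.x;    // one 32-thread block (one warp) per row
    const int tid = threadIdx.x;   // lane id, 0..31
    float partial = 0.0f;
    for (int col = tid; col < ncols; col += 32)
        partial += x[row * ncols + col];
    const float total = warp_reduce_sum_demo(partial);
    if (tid == 0)
        out[row] = total;
}

A launch like row_sum_demo<<<nrows, 32, 0, stream>>>(x, out, ncols) mirrors the <<<nrows, WARP_SIZE, 0, stream>>> launches above; the trade-off of dropping the 1024-thread path is that very wide rows are now strided by 32 threads instead of up to 1024.
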

void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -9191,23 +9167,6 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
#include "ggml.h"
#include "ggml-backend-impl.h"


#include <algorithm>
#include <array>
#include <atomic>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <float.h>
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <stdint.h>
#include <stdio.h>
#include <string>
#include <vector>

static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");

[[noreturn]]
@@ -9570,7 +9529,7 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

if (tensor->view_src != NULL) {
assert(tensor->view_src->buffer->buft == buffer->buft);
GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
return;
}

@@ -11586,10 +11545,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
}

#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
GGML_ASSERT(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
GGML_ASSERT(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || ggml_backend_buffer_is_cuda_split(node->src[j]->buffer));
}
}
#endif
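
A plausible reading of the assert → GGML_ASSERT changes in this file: GGML_ASSERT is not compiled out under NDEBUG, and because it is a macro expanded inside ggml-cuda.cu, its error reporting (typically an fprintf plus abort) would presumably pass through the printf/write bridge added above, so a failed check still produces visible output on Windows. A generic sketch of an always-on assertion macro in that style (illustrative only, not ggml.h's exact definition):

#include <cstdio>
#include <cstdlib>

// Always-on assertion: unlike assert(), it does not disappear under NDEBUG.
#define MY_ASSERT(x)                                                   \
    do {                                                               \
        if (!(x)) {                                                    \
            fprintf(stderr, "assertion failed: %s:%d: %s\n",           \
                    __FILE__, __LINE__, #x);                           \
            abort();                                                   \
        }                                                              \
    } while (0)

int main() {
    int device_count = 0;
    MY_ASSERT(device_count >= 0);   // passes; would print and abort if false
    return 0;
}
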