Skip to content

Commit

Permalink
bring ggml v2_cuda up to date with AMD changes
Browse files Browse the repository at this point in the history
  • Loading branch information
YellowRoseCx committed Aug 25, 2023
1 parent f915a46 commit 3385dd4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
7 changes: 6 additions & 1 deletion otherarch/ggml_v2-cuda-legacy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
Expand All @@ -32,6 +36,7 @@
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
Expand All @@ -54,7 +59,7 @@
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent hipStreamWaitEvent
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#else
Expand Down
8 changes: 6 additions & 2 deletions otherarch/ggml_v2-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
Expand All @@ -32,6 +36,7 @@
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
Expand All @@ -54,14 +59,13 @@
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent hipStreamWaitEvent
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

#endif

#include "ggml_v2-cuda.h"
Expand Down

0 comments on commit 3385dd4

Please sign in to comment.