Skip to content

Commit

Permalink
CUDA: mul_mat_id always on GPU for batches >= 32
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Dec 20, 2023
1 parent 799fc22 commit 54a9a77
Showing 1 changed file with 28 additions and 7 deletions.
35 changes: 28 additions & 7 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8782,8 +8782,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
// TODO: mmq/mmv support
#endif

GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);

const int64_t nb11 = src1->nb[1];
const int64_t nb1 = dst->nb[1];

Expand Down Expand Up @@ -8812,13 +8810,24 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;

ggml_backend_type src1_original_backend = src1_row.backend;
ggml_backend_type dst_original_backend = dst_row.backend;

src1_row.backend = GGML_BACKEND_GPU;
dst_row.backend = GGML_BACKEND_GPU;

src1_row.extra = &src1_row_extra;
dst_row.extra = &dst_row_extra;

char * src1_original = (char *) src1_extra->data_device[g_main_device];
char * dst_original = (char *) dst_extra->data_device[g_main_device];
char * src1_original = src1_original_backend == GGML_BACKEND_CPU ?
(char *) src1->data : (char *) src1_extra->data_device[g_main_device];
char * dst_original = dst_original_backend == GGML_BACKEND_CPU ?
(char *) dst->data : (char *) dst_extra->data_device[g_main_device];

if (src1->ne[1] == 1) {
GGML_ASSERT(src1_original_backend == GGML_BACKEND_GPU);
GGML_ASSERT(dst_original_backend == GGML_BACKEND_GPU);

for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
//int32_t row_id;
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
Expand Down Expand Up @@ -8846,6 +8855,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
src1_row_extra.data_device[g_main_device] = src1_contiguous;
dst_row_extra.data_device[g_main_device] = dst_contiguous;

const cudaMemcpyKind src1_kind = src1_original_backend == GGML_BACKEND_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
const cudaMemcpyKind dst_kind = src1_original_backend == GGML_BACKEND_CPU ?
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;

for (int32_t row_id = 0; row_id < n_as; ++row_id) {
const struct ggml_tensor * src0_row = dst->src[row_id + 2];

Expand All @@ -8860,7 +8874,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);

CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, cudaMemcpyDeviceToDevice, stream));
nb11, src1_kind, stream));
num_src1_rows++;
}

Expand Down Expand Up @@ -8892,14 +8906,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
GGML_ASSERT(row_id >= 0 && row_id < n_as);

CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
nb1, dst_kind, stream));
num_src1_rows++;
}
}

ggml_cuda_pool_free(src1_contiguous, as_src1);
ggml_cuda_pool_free(dst_contiguous, as_dst);
}

src1_row.backend = src1_original_backend;

if (dst_original_backend == GGML_BACKEND_CPU) {
CUDA_CHECK(cudaStreamSynchronize(stream));
}
dst_row.backend = dst_original_backend;
}

static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
Expand Down Expand Up @@ -9298,7 +9319,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);

if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
return false;
}

Expand Down

0 comments on commit 54a9a77

Please sign in to comment.