From 06dfde3e946d45178b7b242adf9621058b0e3439 Mon Sep 17 00:00:00 2001
From: slaren
Date: Sat, 9 Dec 2023 13:21:09 +0100
Subject: [PATCH] llama : add basic support for offloading moe with CUDA

---
 ggml-cuda.cu | 33 ++++++++++++++++++++++++---------
 ggml.c       |  1 -
 llama.cpp    | 46 +++++++++++++++++++++++++++++++++++++---------
 3 files changed, 61 insertions(+), 19 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 04a5d2078941b..ba771870e41ae 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -8242,15 +8242,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     // TODO: mmq/mmv support
 #endif

-    const struct ggml_tensor * ids = src0;
-    const int32_t id = dst->op_params[0];
-    const int32_t n_as = dst->op_params[1];
+    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);

-    const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+    const struct ggml_tensor * ids = src0;
+    const int32_t id = ((int32_t *) dst->op_params)[0];
+    const int32_t n_as = ((int32_t *) dst->op_params)[1];

     std::vector<char> ids_host(ggml_nbytes(ids));
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
-    CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+    if (ids->backend == GGML_BACKEND_GPU) {
+        const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+        CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+        CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+    } else {
+        memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+    }

     const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
     const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
@@ -8264,7 +8270,9 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     src1_row.ne[1] = 1;
     dst_row.ne[1] = 1;

-    src1_row.extra = &src1_row_extra;
+    if (src1->backend == GGML_BACKEND_GPU) {
+        src1_row.extra = &src1_row_extra;
+    }
     dst_row.extra = &dst_row_extra;

     for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -8278,7 +8286,12 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

         const struct ggml_tensor * src0_row = dst->src[row_id + 2];

-        src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        if (src1->backend == GGML_BACKEND_GPU) {
+            src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+        } else {
+            src1_row.data = (char *) src1->data + i01*src1->nb[1];
+        }
+
         dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];

         ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
@@ -8694,7 +8707,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             func = ggml_cuda_repeat;
             break;
         case GGML_OP_GET_ROWS:
-            func = ggml_cuda_get_rows;
+            if (ggml_is_contiguous(tensor->src[1])) {
+                func = ggml_cuda_get_rows;
+            }
             break;
         case GGML_OP_DUP:
             func = ggml_cuda_dup;
diff --git a/ggml.c b/ggml.c
index 5f94ede0067cc..07d23f4275b54 100644
--- a/ggml.c
+++ b/ggml.c
@@ -4105,7 +4105,6 @@ struct ggml_tensor * ggml_mul_mat_id(
     result->src[0] = ids;
     result->src[1] = b;

-    // TODO: n_as is the selected experts, but it should be the total number of experts
     for (int i = 0; i < n_as; i++) {
         struct ggml_tensor * a = as[i];
         GGML_ASSERT(ggml_are_same_shape(as[0], a));
diff --git a/llama.cpp b/llama.cpp
index 3b2a6797971dd..c14aab71f308a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4247,16 +4247,25 @@ struct llm_build_context {
                 const int n_experts_per_tok = 2;

                 ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+                cb(logits, "ffn_moe_logits", il);
+
                 ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+                cb(probs, "ffn_moe_probs", il);
                 // select experts
                 ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]

-                ggml_tensor * weights =
-                    ggml_reshape_2d(ctx0,
-                        ggml_get_rows(ctx0,
-                            ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts),
+                ggml_tensor * weights = ggml_get_rows(ctx0,
+                        ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
+                cb(weights, "ffn_moe_weights", il);
+
+                weights = ggml_reshape_2d(ctx0, weights,
                         n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
-                weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok]
+
+                ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+                cb(weights_sum, "ffn_moe_weights_sum", il);
+
+                weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+                cb(weights, "ffn_moe_weights_norm", il);

                 // compute expert outputs
                 ggml_tensor * moe_out;
@@ -4269,19 +4278,30 @@ struct llm_build_context {
                     ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp;
                     ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;

-                    cur_expert = ggml_mul(ctx0,
-                            ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur),
-                            ggml_silu(ctx0,
-                                ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur))); // [n_tokens, n_embd]
+                    ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, ffn_up_exp, n_experts, selected_experts, i, cur);
+                    cb(cur_up, "ffn_up", il);
+
+                    ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, ffn_gate_exp, n_experts, selected_experts, i, cur);
+                    cb(cur_gate, "ffn_gate", il);
+
+                    cur_gate = ggml_silu(ctx0, cur_gate);
+                    cb(cur_gate, "ffn_silu", il);
+
+                    cur_expert = ggml_mul(ctx0, cur_up, cur_gate); // [n_tokens, n_embd]
+                    cb(cur_expert, "ffn_gate_par", il);

                     cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                    cb(cur_expert, "ffn_down", il);
+
                     cur_expert = ggml_mul(ctx0, cur_expert,
                             ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                    cb(cur_expert, "ffn_moe_weighted", il);

                     if (i == 0) {
                         moe_out = cur_expert;
                     } else {
                         moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                        cb(moe_out, "ffn_moe_out", il);
                     }
                 }

@@ -5540,6 +5560,14 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "ffn_relu",                   OFFLOAD_FUNC },
     { "ffn_sqr(relu)",              OFFLOAD_FUNC },

+    { "ffn_moe_logits",             OFFLOAD_FUNC },
+    { "ffn_moe_probs",              OFFLOAD_FUNC },
+    { "ffn_moe_weights",            OFFLOAD_FUNC_NOP },
+    { "ffn_moe_weights_sum",        OFFLOAD_FUNC },
+    { "ffn_moe_weights_norm",       OFFLOAD_FUNC },
+    { "ffn_moe_weighted",           OFFLOAD_FUNC },
+    { "ffn_moe_out",                OFFLOAD_FUNC },
+
     { "l_out",                      OFFLOAD_FUNC },

     { "result_norm",                OFFLOAD_FUNC_EMB },