diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 5fa4942be7aaf..26563821a1078 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -34,7 +34,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) { } fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 03e801c2a6d4b..5192d6df5c2f8 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,7 +35,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -93,5 +93,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_backend_free(); + return 0; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index ef57a8982c64a..07d8fc6ac0781 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -105,7 +105,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -671,7 +671,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 68f44ba805966..7e120ff12cb42 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -147,7 +147,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1eb0f75d6dc79..797d2f0c5a279 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -180,7 +180,7 @@ int main(int argc, char ** argv) { usage(argv[0]); } - llama_init_backend(false); + llama_backend_init(false); // parse command line arguments const std::string fname_inp = argv[arg_idx]; @@ -257,5 +257,7 @@ int main(int argc, char ** argv) { printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0); } + llama_backend_free(); + return 0; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2cbfc0018de3a..296c5d6468f16 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1079,7 +1079,7 @@ int main(int argc, char **argv) params.model_alias = params.model; } - llama_init_backend(params.numa); + llama_backend_init(params.numa); LOG_INFO("build info", {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); @@ -1309,5 +1309,7 @@ int main(int argc, char **argv) return 1; } + llama_backend_free(); + return 0; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 57a0fb7c5585d..aa2c4352df294 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -66,7 +66,7 @@ int main(int argc, char ** argv) // Init LLM : //--------------------------------- - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -173,7 +173,7 @@ int main(int argc, char ** argv) llama_free( ctx ); llama_free_model( model ); - llama_finalize_backend(); + llama_backend_free(); return 0; } diff --git a/ggml-mpi.c b/ggml-mpi.c index bf301d08b5aee..4bde418089f8c 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -2,80 +2,211 @@ #include "ggml.h" +#include + #include #include -#include + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + #define UNUSED GGML_UNUSED -struct ggml_mpi_tensor_info { +struct ggml_mpi_context { int rank; + int size; }; -// ggml_compute_forward_send +void ggml_mpi_backend_init(void) { + MPI_Init(NULL, NULL); +} -static void ggml_mpi_compute_forward_send( - struct ggml_tensor * src, - const struct ggml_tensor * orig) { - UNUSED(orig); - GGML_ASSERT(src->type == GGML_TYPE_F32); +void ggml_mpi_backend_free(void) { + MPI_Finalize(); +} + +struct ggml_mpi_context * ggml_mpi_init(void) { + struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context)); - int my_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); + MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); - int dst_rank = ((struct ggml_mpi_tensor_info *)src->extra)->rank; - // fprintf(stderr, "(%d) Sending to (%d)\n", my_rank, (int)dst->extra); - int retval = MPI_Send(src->data, ggml_nelements(src), MPI_FLOAT, dst_rank, 0, MPI_COMM_WORLD); - // fprintf(stderr, "(%d) Sent to (%d)\n", my_rank, (int)dst->extra); - GGML_ASSERT(retval == MPI_SUCCESS); + return ctx; } -// ggml_compute_forward_recv - -static void ggml_mpi_compute_forward_recv( - struct ggml_tensor * dst, - const struct ggml_tensor * orig, - const struct ggml_tensor * parent) { - UNUSED(parent); - UNUSED(orig); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - MPI_Status status; - - int my_rank; - MPI_Comm_rank(MPI_COMM_WORLD, &my_rank); - - int src_rank = ((struct ggml_mpi_tensor_info *)dst->extra)->rank; - // fprintf(stderr, "(%d) Receiving from (%d)\n", my_rank, src_extra); - int retval = MPI_Recv(dst->data, ggml_nelements(dst), MPI_FLOAT, src_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status); - // fprintf(stderr, "(%d) Received from (%d)\n", my_rank, src_extra); - GGML_ASSERT(retval == MPI_SUCCESS); +void ggml_mpi_free(struct ggml_mpi_context * ctx) { + free(ctx); } -struct ggml_tensor * ggml_mpi_send_tensor( - struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank) { +int ggml_mpi_rank(struct ggml_mpi_context * ctx) { + return ctx->rank; +} - struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send); +void ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + int * n_tokens, + int * n_past, + int * n_threads) { + UNUSED(ctx_mpi); - // TODO how/when to free this struct? - struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info)); - info->rank = dst_rank; - result->extra = info; + // synchronize the worker node parameters with the root node + MPI_Barrier(MPI_COMM_WORLD); + + MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); +} - return result; +int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { + struct ggml_tensor * t = ggml_graph_get_tensor(gf, name); + if (t == NULL) { + fprintf(stderr, "%s: tensor %s not found\n", __func__, name); + return -1; + } + + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->nodes[i] == t) { + return i; + } + } + + fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name); + return -1; } -struct ggml_tensor * ggml_mpi_recv_tensor( - struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank) { - struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv); +// TODO: there are many improvements that can be done to this implementation +void ggml_mpi_graph_compute( + struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + int n_layers) { + const int mpi_rank = ctx_mpi->rank; + const int mpi_size = ctx_mpi->size; + + struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens"); + if (inp_tokens == NULL) { + fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__); + return; + } + + struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0"); + if (inp0 == NULL) { + fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__); + return; + } + + GGML_ASSERT(inp0 == gf->nodes[0]); + + // distribute the compute graph into slices across the MPI nodes + // + // the main node (0) processes the last layers + the remainder of the compute graph + // and is responsible to pass the input tokens to the first node (1) + // + // node 1: [( 0) * n_per_node, ( 1) * n_per_node) + // node 2: [( 1) * n_per_node, ( 2) * n_per_node) + // ... + // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node) + // node 0: [(n-1) * n_per_node, n_nodes) + // + if (mpi_rank > 0) { + if (mpi_rank == 1) { // the first node receives the input tokens from the main node + MPI_Status status; UNUSED(status); + + const int mpi_rank_src = mpi_rank - 1; + + const int retval = MPI_Recv(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); + } else { // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) + MPI_Status status; UNUSED(status); + + const int mpi_rank_src = mpi_rank - 1; + + //printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src); + const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); + } + } else if (mpi_size > 1) { + // node 0 sends the input tokens to node 1 + { + const int mpi_rank_dst = mpi_rank + 1; + + const int retval = MPI_Send(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_dst, 0, MPI_COMM_WORLD); + GGML_ASSERT(retval == MPI_SUCCESS); + } + + // recv the output data from the last node + { + MPI_Status status; UNUSED(status); + + const int mpi_rank_src = mpi_size - 1; + + //fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src); + const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); + } + } + + { + const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size; + + const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1; + + const int il0 = (mpi_idx + 0) * n_per_node; + const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node); + + char name_l0[GGML_MAX_NAME]; + char name_l1[GGML_MAX_NAME]; + + snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); + snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); + + const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); + const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes; + + if (idx_l0 < 0 || idx_l1 < 0) { + fprintf(stderr, "%s: layer input nodes not found\n", __func__); + return; + } + + // attach the input data to all nodes that need it + // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below) + for (int i = idx_l0; i < idx_l1; i++) { + if (gf->nodes[i]->src0 == gf->nodes[idx_l0]) { + gf->nodes[i]->src0 = inp0; + } + if (gf->nodes[i]->src1 == gf->nodes[idx_l0]) { + gf->nodes[i]->src1 = inp0; + } + } + + // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph + for (int i = 1; i < idx_l1 - idx_l0; i++) { + gf->nodes[i] = gf->nodes[idx_l0 + i]; + gf->grads[i] = gf->grads[idx_l0 + i]; + } + + // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node + if (mpi_idx != 0) { + gf->nodes[0]->op = GGML_OP_NONE; + } + + gf->n_nodes = idx_l1 - idx_l0; + + //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); + } + + ggml_graph_compute(ctx, gf); + + //fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank); + + // send the output data to the next node + if (mpi_rank > 0) { + struct ggml_tensor * output = gf->nodes[gf->n_nodes - 1]; + + const int mpi_rank_dst = (mpi_rank + 1) % mpi_size; - // TODO how/when to free this struct? - struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info)); - info->rank = src_rank; - result->extra = info; + //fprintf(stderr, "%s: node %d: sending %d elements to node %d\n", __func__, mpi_rank, ggml_nelements(output), mpi_rank_dst); - return result; + const int retval = MPI_Send(output->data, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD); + GGML_ASSERT(retval == MPI_SUCCESS); + } } diff --git a/ggml-mpi.h b/ggml-mpi.h index ef5269dc5c74f..02e125cfb624b 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -2,20 +2,33 @@ struct ggml_context; struct ggml_tensor; +struct ggml_cgraph; #ifdef __cplusplus extern "C" { #endif -struct ggml_tensor * ggml_mpi_send_tensor( - struct ggml_context * ctx, - struct ggml_tensor *src, - int dst_rank); -struct ggml_tensor * ggml_mpi_recv_tensor( - struct ggml_context * ctx, - struct ggml_tensor *parent, - struct ggml_tensor *dst, - int src_rank); +struct ggml_mpi_context; + +void ggml_mpi_backend_init(void); +void ggml_mpi_backend_free(void); + +struct ggml_mpi_context * ggml_mpi_init(void); +void ggml_mpi_free(struct ggml_mpi_context * ctx); + +int ggml_mpi_rank(struct ggml_mpi_context * ctx); + +void ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + int * n_tokens, + int * n_past, + int * n_threads); + +void ggml_mpi_graph_compute( + struct ggml_mpi_context * ctx_mpi, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + int n_layers); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index a0333b534948b..325db7d56b4c5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -52,10 +52,6 @@ #include #include -#ifdef GGML_USE_MPI -#include -#endif - #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif @@ -359,8 +355,9 @@ struct llama_context { ggml_metal_context * ctx_metal = NULL; #endif - int mpi_rank; - int mpi_size; +#ifdef GGML_USE_MPI + ggml_mpi_context * ctx_mpi = NULL; +#endif int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -880,7 +877,7 @@ bool llama_mlock_supported() { return llama_mlock::SUPPORTED; } -void llama_init_backend(bool numa) { +void llama_backend_init(bool numa) { ggml_time_init(); // needed to initialize f16 tables @@ -893,14 +890,15 @@ void llama_init_backend(bool numa) { if (numa) { ggml_numa_init(); } + #ifdef GGML_USE_MPI - MPI_Init(NULL, NULL); + ggml_mpi_backend_init(); #endif } -void llama_finalize_backend() { +void llama_backend_free() { #ifdef GGML_USE_MPI - MPI_Finalize(); + ggml_mpi_backend_free(); #endif } @@ -1303,13 +1301,17 @@ static bool llama_eval_internal( llama_context & lctx, const llama_token * tokens, const float * embd, - const int n_tokens, - const int n_past, + int n_tokens, + int n_past, int n_threads, const char * cgraph_fname) { LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); +#ifdef GGML_USE_MPI + ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); +#endif + const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1349,21 +1351,17 @@ static bool llama_eval_internal( struct ggml_tensor * cur; struct ggml_tensor * inpL; - if (lctx.mpi_rank > 0) { + if (tokens) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { #ifdef GGML_USE_MPI - inpL = ggml_mpi_recv_tensor(ctx0, NULL, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), - lctx.mpi_rank-1); - ggml_set_name(inpL, "mpi_recv"); -#else - GGML_ASSERT(false); + GGML_ASSERT(false && "not implemented"); #endif - } else if (tokens) { - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - ggml_set_name(embd, "embd"); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); - inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); - } else { + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); } @@ -1381,20 +1379,20 @@ static bool llama_eval_internal( offload_func_t offload_func_v = llama_nop; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers; - } + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers; + } #endif // GGML_USE_CUBLAS - // EMM TODO distribute work more evenly - maybe rank=0 gets the smallest amount? - int slice_size = (n_layer + (lctx.mpi_size - 1)) / lctx.mpi_size; - for (int il = lctx.mpi_rank * slice_size; il < n_layer && il < (lctx.mpi_rank + 1) * slice_size; ++il) { + for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); + offload_func_t offload_func = llama_nop; #ifdef GGML_USE_CUBLAS @@ -1601,7 +1599,6 @@ static bool llama_eval_internal( // input for next layer inpL = cur; - } lctx.use_buf(ctx0, 0); @@ -1609,45 +1606,24 @@ static bool llama_eval_internal( // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; - if (lctx.mpi_size > 1) { -#ifdef GGML_USE_MPI - cur = ggml_mpi_send_tensor(ctx0, cur, (lctx.mpi_rank+1)%lctx.mpi_size); - ggml_set_name(cur, "mpi_send"); -#else - GGML_ASSERT(false); -#endif - } - if (lctx.mpi_rank == 0) { - if (lctx.mpi_size > 1) { -#ifdef GGML_USE_MPI - cur = ggml_mpi_recv_tensor(ctx0, cur, - ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N), - lctx.mpi_size-1); - ggml_set_name(cur, "mpi_recv"); -#else - GGML_ASSERT(false); -#endif - } - // norm - { - cur = ggml_rms_norm(ctx0, cur); - offload_func_nr(cur); - ggml_set_name(cur, "rms_norm_2"); - - // cur = cur*norm(broadcasted) - cur = ggml_mul(ctx0, cur, model.norm); - // offload_func_nr(cur); // TODO CPU + GPU mirrored backend - ggml_set_name(cur, "result_norm"); - - embeddings = cur; - } + // norm + { + cur = ggml_rms_norm(ctx0, inpL); + offload_func_nr(cur); + ggml_set_name(cur, "rms_norm_2"); + // cur = cur*norm(broadcasted) + cur = ggml_mul(ctx0, cur, model.norm); + // offload_func_nr(cur); // TODO CPU + GPU mirrored backend + ggml_set_name(cur, "result_norm"); - // lm_head - cur = ggml_mul_mat(ctx0, model.output, cur); - ggml_set_name(cur, "result_output"); + embeddings = cur; } + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + lctx.use_buf(ctx0, -1); // logits -> probs @@ -1680,6 +1656,10 @@ static bool llama_eval_internal( ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); } +#elif GGML_USE_MPI + ggml_mpi_graph_compute(lctx.ctx_mpi, ctx0, &gf, n_layer); + + cur = gf.nodes[gf.n_nodes - 1]; #else ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); #endif @@ -1705,7 +1685,11 @@ static bool llama_eval_internal( // update kv token count lctx.kv_self.n = n_past + N; - if (lctx.mpi_rank == 0) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(lctx.ctx_mpi) == 0) { +#else + { +#endif // extract logits { auto & logits_out = lctx.logits; @@ -2676,14 +2660,6 @@ struct llama_context * llama_new_context_with_model( ctx->rng = std::mt19937(params.seed); ctx->logits_all = params.logits_all; -#ifdef GGML_USE_MPI - MPI_Comm_size(MPI_COMM_WORLD, &ctx->mpi_size); - MPI_Comm_rank(MPI_COMM_WORLD, &ctx->mpi_rank); -#else - ctx->mpi_size = 1; - ctx->mpi_rank = 0; -#endif - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; // reserve memory for context buffers @@ -2756,15 +2732,17 @@ struct llama_context * llama_new_context_with_model( } #endif - if (ctx->mpi_rank > 0) { - // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - const std::vector tmp = { llama_token_bos(), }; - while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)); #ifdef GGML_USE_MPI - MPI_Finalize(); -#endif + ctx->ctx_mpi = ggml_mpi_init(); + + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos()); + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); exit(1); } +#endif return ctx; } @@ -3443,13 +3421,6 @@ int llama_eval( int n_tokens, int n_past, int n_threads) { -#ifdef GGML_USE_MPI - // Synchronize the worker node parameters with the root node - MPI_Barrier(MPI_COMM_WORLD); - MPI_Bcast(&n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(&n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(&n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); -#endif if (!llama_eval_internal(*ctx, tokens, nullptr, n_tokens, n_past, n_threads, nullptr)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; diff --git a/llama.h b/llama.h index b90c523555da8..686463aa25af8 100644 --- a/llama.h +++ b/llama.h @@ -158,9 +158,9 @@ extern "C" { // Initialize the llama + ggml backend // If numa is true, use NUMA optimizations // Call once at the start of the program - LLAMA_API void llama_init_backend(bool numa); + LLAMA_API void llama_backend_init(bool numa); // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_finalize_backend(); + LLAMA_API void llama_backend_free(); LLAMA_API int64_t llama_time_us();