Skip to content

Commit

Permalink
Merge pull request #1 from ggerganov/refactor-mpi
Browse files Browse the repository at this point in the history
mpi : trying to move more MPI stuff into ggml-mpi (ggerganov#2099)
  • Loading branch information
evanmiller authored Jul 9, 2023
2 parents 03cc12b + 81c5ddd commit 1c3a15c
Show file tree
Hide file tree
Showing 11 changed files with 292 additions and 171 deletions.
2 changes: 1 addition & 1 deletion examples/embd-input/embd-input-lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) {
}
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);

llama_init_backend(params.numa);
llama_backend_init(params.numa);

llama_model * model;
llama_context * ctx;
Expand Down
4 changes: 3 additions & 1 deletion examples/embedding/embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng);
}

llama_init_backend(params.numa);
llama_backend_init(params.numa);

llama_model * model;
llama_context * ctx;
Expand Down Expand Up @@ -93,5 +93,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);

llama_backend_free();

return 0;
}
4 changes: 2 additions & 2 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng);
}

llama_init_backend(params.numa);
llama_backend_init(params.numa);

llama_model * model;
llama_context * ctx;
Expand Down Expand Up @@ -671,7 +671,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);

llama_finalize_backend();
llama_backend_free();

return 0;
}
4 changes: 2 additions & 2 deletions examples/perplexity/perplexity.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ int main(int argc, char ** argv) {
params.prompt = gpt_random_prompt(rng);
}

llama_init_backend(params.numa);
llama_backend_init(params.numa);

llama_model * model;
llama_context * ctx;
Expand All @@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);

llama_finalize_backend();
llama_backend_free();

return 0;
}
4 changes: 3 additions & 1 deletion examples/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ int main(int argc, char ** argv) {
usage(argv[0]);
}

llama_init_backend(false);
llama_backend_init(false);

// parse command line arguments
const std::string fname_inp = argv[arg_idx];
Expand Down Expand Up @@ -257,5 +257,7 @@ int main(int argc, char ** argv) {
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
}

llama_backend_free();

return 0;
}
4 changes: 3 additions & 1 deletion examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ int main(int argc, char **argv)
params.model_alias = params.model;
}

llama_init_backend(params.numa);
llama_backend_init(params.numa);

LOG_INFO("build info", {{"build", BUILD_NUMBER},
{"commit", BUILD_COMMIT}});
Expand Down Expand Up @@ -1309,5 +1309,7 @@ int main(int argc, char **argv)
return 1;
}

llama_backend_free();

return 0;
}
4 changes: 2 additions & 2 deletions examples/simple/simple.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ int main(int argc, char ** argv)
// Init LLM :
//---------------------------------

llama_init_backend(params.numa);
llama_backend_init(params.numa);

llama_model * model;
llama_context * ctx;
Expand Down Expand Up @@ -173,7 +173,7 @@ int main(int argc, char ** argv)
llama_free( ctx );
llama_free_model( model );

llama_finalize_backend();
llama_backend_free();

return 0;
}
Expand Down
241 changes: 186 additions & 55 deletions ggml-mpi.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,80 +2,211 @@

#include "ggml.h"

#include <mpi.h>

#include <stdio.h>
#include <stdlib.h>
#include <mpi.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

#define UNUSED GGML_UNUSED

struct ggml_mpi_tensor_info {
struct ggml_mpi_context {
int rank;
int size;
};

// ggml_compute_forward_send
void ggml_mpi_backend_init(void) {
MPI_Init(NULL, NULL);
}

static void ggml_mpi_compute_forward_send(
struct ggml_tensor * src,
const struct ggml_tensor * orig) {
UNUSED(orig);
GGML_ASSERT(src->type == GGML_TYPE_F32);
void ggml_mpi_backend_free(void) {
MPI_Finalize();
}

struct ggml_mpi_context * ggml_mpi_init(void) {
struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));

int my_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);

int dst_rank = ((struct ggml_mpi_tensor_info *)src->extra)->rank;
// fprintf(stderr, "(%d) Sending to (%d)\n", my_rank, (int)dst->extra);
int retval = MPI_Send(src->data, ggml_nelements(src), MPI_FLOAT, dst_rank, 0, MPI_COMM_WORLD);
// fprintf(stderr, "(%d) Sent to (%d)\n", my_rank, (int)dst->extra);
GGML_ASSERT(retval == MPI_SUCCESS);
return ctx;
}

// ggml_compute_forward_recv

static void ggml_mpi_compute_forward_recv(
struct ggml_tensor * dst,
const struct ggml_tensor * orig,
const struct ggml_tensor * parent) {
UNUSED(parent);
UNUSED(orig);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
MPI_Status status;

int my_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &my_rank);

int src_rank = ((struct ggml_mpi_tensor_info *)dst->extra)->rank;
// fprintf(stderr, "(%d) Receiving from (%d)\n", my_rank, src_extra);
int retval = MPI_Recv(dst->data, ggml_nelements(dst), MPI_FLOAT, src_rank, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
// fprintf(stderr, "(%d) Received from (%d)\n", my_rank, src_extra);
GGML_ASSERT(retval == MPI_SUCCESS);
void ggml_mpi_free(struct ggml_mpi_context * ctx) {
free(ctx);
}

struct ggml_tensor * ggml_mpi_send_tensor(
struct ggml_context * ctx,
struct ggml_tensor *src,
int dst_rank) {
int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
return ctx->rank;
}

struct ggml_tensor * result = ggml_map_custom1_inplace_f32(ctx, src, ggml_mpi_compute_forward_send);
void ggml_mpi_eval_init(
struct ggml_mpi_context * ctx_mpi,
int * n_tokens,
int * n_past,
int * n_threads) {
UNUSED(ctx_mpi);

// TODO how/when to free this struct?
struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info));
info->rank = dst_rank;
result->extra = info;
// synchronize the worker node parameters with the root node
MPI_Barrier(MPI_COMM_WORLD);

MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
}

return result;
int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
struct ggml_tensor * t = ggml_graph_get_tensor(gf, name);
if (t == NULL) {
fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
return -1;
}

for (int i = 0; i < gf->n_nodes; i++) {
if (gf->nodes[i] == t) {
return i;
}
}

fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name);
return -1;
}

struct ggml_tensor * ggml_mpi_recv_tensor(
struct ggml_context * ctx,
struct ggml_tensor *parent,
struct ggml_tensor *dst,
int src_rank) {
struct ggml_tensor * result = ggml_map_custom2_inplace_f32(ctx, dst, parent, ggml_mpi_compute_forward_recv);
// TODO: there are many improvements that can be done to this implementation
void ggml_mpi_graph_compute(
struct ggml_mpi_context * ctx_mpi,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
int n_layers) {
const int mpi_rank = ctx_mpi->rank;
const int mpi_size = ctx_mpi->size;

struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens");
if (inp_tokens == NULL) {
fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
return;
}

struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0");
if (inp0 == NULL) {
fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
return;
}

GGML_ASSERT(inp0 == gf->nodes[0]);

// distribute the compute graph into slices across the MPI nodes
//
// the main node (0) processes the last layers + the remainder of the compute graph
// and is responsible to pass the input tokens to the first node (1)
//
// node 1: [( 0) * n_per_node, ( 1) * n_per_node)
// node 2: [( 1) * n_per_node, ( 2) * n_per_node)
// ...
// node n-1: [(n-2) * n_per_node, (n-1) * n_per_node)
// node 0: [(n-1) * n_per_node, n_nodes)
//
if (mpi_rank > 0) {
if (mpi_rank == 1) { // the first node receives the input tokens from the main node
MPI_Status status; UNUSED(status);

const int mpi_rank_src = mpi_rank - 1;

const int retval = MPI_Recv(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
GGML_ASSERT(retval == MPI_SUCCESS);
} else { // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
MPI_Status status; UNUSED(status);

const int mpi_rank_src = mpi_rank - 1;

//printf("%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src);
const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
GGML_ASSERT(retval == MPI_SUCCESS);
}
} else if (mpi_size > 1) {
// node 0 sends the input tokens to node 1
{
const int mpi_rank_dst = mpi_rank + 1;

const int retval = MPI_Send(inp_tokens->data, ggml_nelements(inp_tokens), MPI_INT, mpi_rank_dst, 0, MPI_COMM_WORLD);
GGML_ASSERT(retval == MPI_SUCCESS);
}

// recv the output data from the last node
{
MPI_Status status; UNUSED(status);

const int mpi_rank_src = mpi_size - 1;

//fprintf(stderr, "%s: node %d: waiting for %d elements from %d\n", __func__, mpi_rank, (int) ggml_nelements(inp0), mpi_rank_src);
const int retval = MPI_Recv(inp0->data, ggml_nelements(inp0), MPI_FLOAT, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
GGML_ASSERT(retval == MPI_SUCCESS);
}
}

{
const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;

const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;

const int il0 = (mpi_idx + 0) * n_per_node;
const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node);

char name_l0[GGML_MAX_NAME];
char name_l1[GGML_MAX_NAME];

snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0);
snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1);

const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0);
const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes;

if (idx_l0 < 0 || idx_l1 < 0) {
fprintf(stderr, "%s: layer input nodes not found\n", __func__);
return;
}

// attach the input data to all nodes that need it
// TODO: not great - should be able to do this without modifying the compute graph (see next TODO below)
for (int i = idx_l0; i < idx_l1; i++) {
if (gf->nodes[i]->src0 == gf->nodes[idx_l0]) {
gf->nodes[i]->src0 = inp0;
}
if (gf->nodes[i]->src1 == gf->nodes[idx_l0]) {
gf->nodes[i]->src1 = inp0;
}
}

// TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
for (int i = 1; i < idx_l1 - idx_l0; i++) {
gf->nodes[i] = gf->nodes[idx_l0 + i];
gf->grads[i] = gf->grads[idx_l0 + i];
}

// the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node
if (mpi_idx != 0) {
gf->nodes[0]->op = GGML_OP_NONE;
}

gf->n_nodes = idx_l1 - idx_l0;

//fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
}

ggml_graph_compute(ctx, gf);

//fprintf(stderr, "%s: node %d: done\n", __func__, mpi_rank);

// send the output data to the next node
if (mpi_rank > 0) {
struct ggml_tensor * output = gf->nodes[gf->n_nodes - 1];

const int mpi_rank_dst = (mpi_rank + 1) % mpi_size;

// TODO how/when to free this struct?
struct ggml_mpi_tensor_info *info = calloc(1, sizeof(struct ggml_mpi_tensor_info));
info->rank = src_rank;
result->extra = info;
//fprintf(stderr, "%s: node %d: sending %d elements to node %d\n", __func__, mpi_rank, ggml_nelements(output), mpi_rank_dst);

return result;
const int retval = MPI_Send(output->data, ggml_nelements(output), MPI_FLOAT, mpi_rank_dst, 0, MPI_COMM_WORLD);
GGML_ASSERT(retval == MPI_SUCCESS);
}
}
Loading

0 comments on commit 1c3a15c

Please sign in to comment.