unified mem for multi_gpu_float_sum
Chinthaka Gamanayakege committed Jun 5, 2024
1 parent 53288e9 commit df038ad
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions train_gpt2.cu
@@ -152,9 +152,9 @@ MultiGpuConfig multi_gpu_config_init(int *argc, char ***argv) {
         fclose(idFile);
     } else { // Other ranks wait until the file is available and read the unique ID
         do {
-            usleep(1000000);
             idFile = fopen(filename, "rb");
             if (idFile != NULL) break;
+            usleep(100000);
         } while (idFile == NULL);
         fread(&nccl_id, sizeof(nccl_id), 1, idFile);
         fclose(idFile);
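For context (this note is not part of the diff): the hunk above is the reader side of the NCCL unique-ID handshake. Rank 0 generates the ID and writes it to a file on shared storage, and every other rank polls for that file before joining the communicator; the change retries fopen first and sleeps only 100 ms between attempts instead of sleeping a full second up front. Below is a minimal sketch of both sides of that pattern, assuming NCCL headers and a shared filesystem; the helper name and filename are illustrative, not taken from train_gpt2.cu, and error checks are omitted.

    #include <nccl.h>
    #include <stdio.h>
    #include <unistd.h>

    // Hypothetical helper: exchange the NCCL unique ID through a shared file.
    ncclUniqueId exchange_nccl_id(int process_rank, const char* filename) {
        ncclUniqueId nccl_id;
        if (process_rank == 0) {
            ncclGetUniqueId(&nccl_id);                  // rank 0 creates the ID
            FILE* idFile = fopen(filename, "wb");
            fwrite(&nccl_id, sizeof(nccl_id), 1, idFile);
            fclose(idFile);
        } else {
            FILE* idFile = NULL;
            do {                                        // other ranks poll until the file exists
                idFile = fopen(filename, "rb");
                if (idFile != NULL) break;
                usleep(100000);                         // 100 ms between attempts, as in the new code
            } while (idFile == NULL);
            fread(&nccl_id, sizeof(nccl_id), 1, idFile);
            fclose(idFile);
        }
        return nccl_id;
    }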
@@ -493,7 +493,7 @@ typedef struct {
     int* targets; // the target tokens for the current forward pass
     float mean_loss; // after a forward pass with targets, will be populated with the mean loss
     float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs
-    float* loss_buffer; // GPU buffer to avg loss across process
+    float* unified_buffer; // GPU buffer to avg loss across process
     floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
     float* cpu_losses_fp32; // same but fp32
     unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
@@ -515,7 +515,7 @@ void gpt2_init_common(GPT2 *model) {
     model->targets = NULL;
     model->cpu_losses = NULL;
     model->cpu_losses_fp32 = NULL;
-    model->loss_buffer = NULL;
+    model->unified_buffer = NULL;
     // the B,T params are determined and set, fixed on first batch in forward()
     model->batch_size = 0;
     model->seq_len = 0;
@@ -1042,19 +1042,20 @@ void gpt2_backward(GPT2 *model, int* inputs) {
 }

 // Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
-float multi_gpu_float_sum(float value, float *d_buffer, const MultiGpuConfig* multi_gpu_config) {
+float multi_gpu_float_sum(float value, float *unified_buffer, const MultiGpuConfig* multi_gpu_config) {
 #ifdef MULTI_GPU
     if (multi_gpu_config->num_processes == 1) return value;

-    float result;
     if (multi_gpu_config->slurm_managed) { //If the process is managed by slurm we don't need to use MPI
-        if (d_buffer == NULL) cudaCheck(cudaMalloc(&d_buffer, sizeof(float)));
-        cudaCheck(cudaMemcpy(d_buffer, &value, sizeof(float), cudaMemcpyHostToDevice));
-        ncclCheck(ncclAllReduce(d_buffer, d_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, 0));
-        cudaCheck(cudaMemcpy(&result, d_buffer, sizeof(float), cudaMemcpyDeviceToHost));
-        return result;
+        if (unified_buffer == NULL) cudaCheck(cudaMallocManaged(&unified_buffer, sizeof(float)));
+        *unified_buffer = value;
+        cudaCheck(cudaMemPrefetchAsync(unified_buffer, sizeof(float), multi_gpu_config->local_device_idx, 0));
+        ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, 0));
+        cudaCheck(cudaMemPrefetchAsync(unified_buffer, sizeof(float), cudaCpuDeviceId, 0));
+        return *unified_buffer;
     }
     // note MPI doesn't support all reduce with mean, only sum
+    float result;
     mpiCheck(MPI_Allreduce(&value, &result, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD));
     return result;
 #else
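This hunk is the core of the commit: instead of cudaMalloc'ing a device scratch buffer and shuttling the scalar through two explicit cudaMemcpy calls, the function now allocates the buffer once with cudaMallocManaged, writes it directly from the host, prefetches it to the local GPU for the NCCL all-reduce, and prefetches it back before returning the result. A side effect visible in the other hunks is the rename loss_buffer -> unified_buffer, and the same buffer is reused for the ZeRO-1 gradient-norm reduction further down, replacing the old multi_gpu_cpu_float_sum helper. Below is a self-contained sketch (not from the diff) of that managed-memory round trip; a trivial kernel stands in for ncclAllReduce so it can be compiled with nvcc and run on a single GPU, and the values, names, and omission of error checks are all illustrative.

    #include <cstdio>
    #include <cuda_runtime.h>

    // Stand-in for the NCCL all-reduce: pretend every one of `world_size`
    // ranks contributed the same value, so the sum is value * world_size.
    __global__ void fake_all_reduce_sum(float* buf, int world_size) {
        *buf *= world_size;
    }

    int main() {
        int device = 0;
        cudaSetDevice(device);

        float* unified_buffer = nullptr;
        cudaMallocManaged(&unified_buffer, sizeof(float));  // one pointer, valid on CPU and GPU

        *unified_buffer = 2.5f;                                                   // plain host write, no cudaMemcpy
        cudaMemPrefetchAsync(unified_buffer, sizeof(float), device, 0);           // migrate pages to the GPU
        fake_all_reduce_sum<<<1, 1>>>(unified_buffer, 8);                         // the collective would run here
        cudaMemPrefetchAsync(unified_buffer, sizeof(float), cudaCpuDeviceId, 0);  // migrate back toward the host
        cudaDeviceSynchronize();                                                  // wait before the host reads

        printf("sum across 8 fake ranks: %f\n", *unified_buffer);                 // prints 20.000000
        cudaFree(unified_buffer);
        return 0;
    }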
@@ -1070,7 +1071,7 @@ void gpt2_multi_gpu_loss_and_grad_reduce(GPT2* model, MultiGpuConfig* multi_gpu_
     // If there's only one process, there is nothing to do
     if (multi_gpu_config->num_processes == 1) { return; }
     // Average all losses.
-    model->accumulated_mean_loss = multi_gpu_float_sum(model->mean_loss, model->loss_buffer, multi_gpu_config) / multi_gpu_config->num_processes;
+    model->accumulated_mean_loss = multi_gpu_float_sum(model->mean_loss, model->unified_buffer, multi_gpu_config) / multi_gpu_config->num_processes;
     if(multi_gpu_config->zero_stage == 0) {
         // no ZERO == standard DDP: Average all gradients.
         ncclCheck(ncclAllReduce(model->grads_memory, model->grads_memory,
@@ -1136,7 +1137,7 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl
     cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost));
     if (multi_gpu_config->zero_stage == 1) {
         // further sum the (partial) squared norm across all GPUs (see comment ^1 above)
-        grad_norm_squared_cpu = multi_gpu_cpu_float_sum(grad_norm_squared_cpu);
+        grad_norm_squared_cpu = multi_gpu_float_sum(grad_norm_squared_cpu, model->unified_buffer, multi_gpu_config);
     }

     if(!isfinite(grad_norm_squared_cpu)) {
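The gradient-norm hunk now routes through the same multi_gpu_float_sum / unified_buffer path. The reason a plain sum is the right reduction here, as the in-code comment notes, is that under ZeRO stage 1 each rank computes the squared norm of only its own gradient shard; the shards are disjoint, so the global squared norm is exactly the sum of the per-rank partials, and the square root is taken only after that sum. A tiny CPU-only illustration (not from the diff, with made-up shard values):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        // pretend 2 ranks each hold half of the gradient vector [3, 4, 0, 12]
        float shard0[2] = {3.0f, 4.0f};
        float shard1[2] = {0.0f, 12.0f};

        float partial0 = shard0[0]*shard0[0] + shard0[1]*shard0[1];   // 25
        float partial1 = shard1[0]*shard1[0] + shard1[1]*shard1[1];   // 144

        // the "all-reduce": sum the partial squared norms across ranks
        float grad_norm_squared = partial0 + partial1;                // 169
        printf("global grad norm = %f\n", sqrtf(grad_norm_squared));  // 13
        return 0;
    }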
@@ -1678,7 +1679,7 @@ int main(int argc, char *argv[]) {
                 val_loss += model.mean_loss;
             }
             val_loss /= val_num_batches;
-            val_loss = multi_gpu_float_sum(val_loss, model.loss_buffer, &multi_gpu_config) / multi_gpu_config.num_processes;
+            val_loss = multi_gpu_float_sum(val_loss, model.unified_buffer, &multi_gpu_config) / multi_gpu_config.num_processes;
             printf0("val loss %f\n", val_loss);
             logger_log_val(&logger, step, val_loss);
         }
@@ -1697,7 +1698,7 @@ int main(int argc, char *argv[]) {
                 eval_acc_norm += (float)correct;
             }
             // careful because not all ranks may have the exact same allocation of number of examples
-            eval_acc_norm = multi_gpu_float_sum(eval_acc_norm, model.loss_buffer, &multi_gpu_config);
+            eval_acc_norm = multi_gpu_float_sum(eval_acc_norm, model.unified_buffer, &multi_gpu_config);
             printf0("HellaSwag: %d/%d = %f\n", (int)eval_acc_norm, eval_loader.num_examples, eval_acc_norm / eval_loader.num_examples);
             logger_log_eval(&logger, step, eval_acc_norm / eval_loader.num_examples);
         }
