diff --git a/train_gpt2.cu b/train_gpt2.cu
index 1c7d04e39..9ddee27fc 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -194,10 +194,25 @@ void multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) {
 #endif
 }
 
-void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config) {
+// Blocks the calling process until every process has reached this call.
+// No-op unless built with MULTI_GPU and num_processes > 1.
+//   multi_gpu_config: supplies num_processes, slurm_managed and nccl_comm.
+//   unified_buffer:   scratch float handed to NCCL; element 0 is summed in
+//                     place across ranks, so it must not hold live data.
+void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config, float *unified_buffer) {
 #ifdef MULTI_GPU
     if (multi_gpu_config->num_processes > 1) {
-        mpiCheck(MPI_Barrier(MPI_COMM_WORLD));
+        if (!multi_gpu_config->slurm_managed) {
+            // Dummy one-element all-reduce used purely as a rendezvous.
+            // NCCL's count argument is in elements, not bytes.
+            ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, 1, ncclFloat, ncclSum, multi_gpu_config->nccl_comm, 0));
+            // The collective is only enqueued on stream 0; wait for it so the host really blocks.
+            cudaCheck(cudaStreamSynchronize(0));
+        }
+        else {
+            // NOTE(review): MPI is used only when slurm_managed is set - confirm this is not inverted.
+            mpiCheck(MPI_Barrier(MPI_COMM_WORLD));
+        }
     }
 #endif
 }
@@ -1763,13 +1778,13 @@ int main(int argc, char *argv[]) {
             snprintf(filename_buffer, 512, "%s/state_%08d_%05d.bin", output_log_dir, step, multi_gpu_config.process_rank);
             save_state(filename_buffer, step, &model, &train_loader);
             // DONE file is a signal that this checkpoint as a whole is complete
-            multi_gpu_barrier(&multi_gpu_config);
+            multi_gpu_barrier(&multi_gpu_config, model.unified_buffer);
             if (multi_gpu_config.process_rank == 0) {
                 snprintf(filename_buffer, 512, "%s/DONE_%08d", output_log_dir, step);
                 FILE* done_file = fopenCheck(filename_buffer, "w");
                 fclose(done_file);
             }
-            multi_gpu_barrier(&multi_gpu_config);
+            multi_gpu_barrier(&multi_gpu_config, model.unified_buffer);
         }
         resuming = 0;
 