Commit

fix guard of torch.cuda based on device type
karpathy committed Jun 17, 2024
1 parent 136b65e commit 6104ab1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion train_gpt2.py
@@ -506,7 +506,8 @@ def get_lr(it):
     for param_group in optimizer.param_groups:
         param_group['lr'] = lr
     optimizer.step()
-    torch.cuda.synchronize() # wait for the GPU to finish work
+    if device_type == "cuda":
+        torch.cuda.synchronize() # wait for the GPU to finish work
     t1 = time.time()
     dt = t1 - t0 # time difference in seconds
     tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
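For context, here is a minimal sketch of the pattern this commit introduces: guarding torch.cuda.synchronize() behind the device type so the step-timing code still runs on machines without a usable CUDA device, where calling into torch.cuda typically raises an error. The device / device_type setup below is an assumption for illustration in the style of train_gpt2.py, not the exact code from the repository.

import time
import torch

# assumed setup for illustration: derive a coarse device_type from the device string
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = "cuda" if device.startswith("cuda") else "cpu"

t0 = time.time()
# ... the forward/backward pass and optimizer.step() for one training step would run here ...
if device_type == "cuda":
    torch.cuda.synchronize()  # CUDA kernels launch asynchronously; block so t1 - t0 measures real GPU work
t1 = time.time()
print(f"step time: {(t1 - t0) * 1000:.2f} ms")

The guard matters because CUDA work is asynchronous, so synchronizing before reading the clock is needed for accurate timing on a GPU, while on CPU the call is unnecessary (ops run synchronously) and on a machine without CUDA support it generally fails outright.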
