
Commit 8e9d973
enable empty cache on XPU device (#438)
ys950902 authored Aug 26, 2024
1 parent 4f9f1f6 commit 8e9d973
Showing 1 changed file with 3 additions and 3 deletions.
megatron/training.py (6 changes: 3 additions & 3 deletions)
@@ -725,7 +725,7 @@ def train_step(forward_step_func, data_iterator,
 
     # Empty unused memory.
     if args.empty_unused_memory_level >= 1:
-        torch.cuda.empty_cache()
+        get_accelerator().empty_cache()
 
     # Reduce gradients.
     if not args.deepspeed:
@@ -782,7 +782,7 @@ def train_step(forward_step_func, data_iterator,
 
     # Empty unused memory.
     if args.empty_unused_memory_level >= 2:
-        torch.cuda.empty_cache()
+        get_accelerator().empty_cache()
 
     if mpu.is_pipeline_last_stage(ignore_virtual=True):
         # Average loss across microbatches.
@@ -1438,7 +1438,7 @@ def evaluate(forward_step_func,
 
             # Empty unused memory
            if args.empty_unused_memory_level >= 1:
-                torch.cuda.empty_cache()
+                get_accelerator().empty_cache()
 
            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
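
The change swaps the CUDA-only torch.cuda.empty_cache() call for the accelerator abstraction's empty_cache(), so the same code path can release cached device memory on Intel XPU as well. Below is a minimal sketch of the pattern, assuming get_accelerator here is DeepSpeed's deepspeed.accelerator.get_accelerator; the helper name maybe_empty_cache is hypothetical and not part of megatron/training.py.

# Sketch only: illustrates the device-agnostic cache release this commit adopts.
# Assumption: get_accelerator is deepspeed.accelerator.get_accelerator; on CUDA
# builds its empty_cache() forwards to torch.cuda.empty_cache(), while an XPU
# build forwards to the XPU equivalent.
from deepspeed.accelerator import get_accelerator

def maybe_empty_cache(empty_unused_memory_level, threshold):
    """Release cached device memory when the configured level meets the threshold."""
    if empty_unused_memory_level >= threshold:
        # Dispatches to whichever backend DeepSpeed selected (cuda, xpu, ...),
        # unlike the CUDA-only call it replaces.
        get_accelerator().empty_cache()

Called with a threshold of 1 after the forward/backward pass and 2 after the optimizer step, this mirrors the three call sites touched by the diff.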
