Skip to content

Commit

Permalink
xm.save() should not set sync_xla_data=True when sync'ing. (#8484) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcuiaws authored Dec 19, 2024
1 parent de9a01e commit ef85771
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
22 changes: 22 additions & 0 deletions test/test_input_output_aliases.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys

import torch
Expand Down Expand Up @@ -162,6 +163,27 @@ def test_separate_graphs(self):

self.assertEqual(t1.item(), 3)

def test_xm_save_no_aliasing(self):
"""
Test that xm.save() does not perform aliasing.
"""
xla_device = xm.xla_device()
t0 = torch.tensor([1], device=xla_device)
t1 = torch.tensor([2], device=xla_device)
xm.mark_step()

t2 = t0 + t1
t1.add_(1)

# Save the new value of t1 should not result in the old value
# being donated...
xm.save(t1, os.devnull)

# otherwise this mark_step could crash, or compute the wrong value
# for t2.
xm.mark_step()

self.assertEqual(t2.item(), 3)

if __name__ == '__main__':
test = unittest.main()
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,7 @@ def _maybe_convert_to_cpu(data: Any, convert: bool = True) -> ToXlaTensorArena:

def convert_fn(tensors):
torch_xla._XLAC._xla_sync_multi(
tensors, devices=[], wait=True, sync_xla_data=True)
tensors, devices=[], wait=True, sync_xla_data=False)
if not convert:
return tensors
return torch_xla._XLAC._xla_get_cpu_tensors(tensors)
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _rewrite_data(path, data, save_tensors):

def convert_fn(tensors):
torch_xla._XLAC._xla_sync_multi(
tensors, devices=[], wait=True, sync_xla_data=True)
tensors, devices=[], wait=True, sync_xla_data=False)
rewritten_tensors = []
for i, t in enumerate(tensors):
if save_tensors:
Expand Down

0 comments on commit ef85771

Please sign in to comment.