Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Oct 3, 2024
1 parent 009abc3 commit 9798b2d
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10417,11 +10417,13 @@ def set_(x):
result._consolidated = {"storage": storage_cast}
if "metadata" in self._consolidated:
result._consolidated["metadata"] = deepcopy(self._consolidated["metadata"])
if not non_blocking:
if device.type == "cuda":
if non_blocking in (False, None):
if device.type == "cuda" and non_blocking is False:
# sending to CUDA force sync
cuda_device = device
elif storage.device.type == "cuda":
cuda_device = device
# sending from cuda: need sync unless intentionally not asked for
cuda_device = storage.device.type
else:
cuda_device = None
if cuda_device is not None:
Expand Down

0 comments on commit 9798b2d

Please sign in to comment.