Skip to content

Commit

Permalink
[BugFix] Fix non-blocking arg in copy_ (#590)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Dec 4, 2023
1 parent 86c239f commit a25b22b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 8 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2126,8 +2126,14 @@ def _create_nested_tuple(self, key):
if len(key) > 1:
td._create_nested_tuple(key[1:])

def copy_(self, tensordict: T) -> T:
"""See :obj:`TensorDictBase.update_`."""
def copy_(self, tensordict: T, non_blocking: bool = None) -> T:
"""See :obj:`TensorDictBase.update_`.
The non-blocking argument will be ignored and is just present for
compatibility with :func:`torch.Tensor.copy_`.
"""
if non_blocking is False:
raise ValueError("non_blocking=False isn't supported in TensorDict.")
return self.update_(tensordict)

def copy_at_(self, tensordict: T, idx: IndexType) -> T:
Expand Down
4 changes: 3 additions & 1 deletion tensordict/memmap_deprec.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,9 @@ def cuda(self) -> torch.Tensor:
def numpy(self) -> np.ndarray:
return self._tensor.numpy()

def copy_(self, other: torch.Tensor | MemmapTensor) -> MemmapTensor:
def copy_(
self, other: torch.Tensor | MemmapTensor, non_blocking: bool = False
) -> MemmapTensor:
if isinstance(other, MemmapTensor) and other.filename == self.filename:
if not self.shape == other.shape:
raise ValueError(
Expand Down

0 comments on commit a25b22b

Please sign in to comment.