Commit
[BugFix, Performance] Fewer imports at root (#682)
vmoens committed Feb 26, 2024
1 parent b234ead commit 17fd94c
Showing 2 changed files with 19 additions and 14 deletions.
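
The theme of the diff: `import tensordict` should not eagerly pull in heavy optional dependencies. After this commit, `memmap.py` no longer touches `torch.distributed` at module load, and `persistent.py` defers `h5py` until an H5-backed tensordict is actually used. To check the effect of a change like this, the import can be timed in a fresh interpreter; the snippet below is a generic sketch, not part of the commit:

    import subprocess
    import sys

    # Spawn a fresh interpreter so nothing is pre-cached in sys.modules.
    # -X importtime makes CPython print a per-module import-time breakdown,
    # which shows which root-level imports dominate startup.
    subprocess.run([sys.executable, "-X", "importtime", "-c", "import tensordict"])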
tensordict/memmap.py (1 addition, 13 deletions)
@@ -22,20 +22,8 @@
 from tensordict.utils import implement_for

-from torch import distributed as dist
-
 from torch.multiprocessing.reductions import ForkingPickler

-try:
-    if dist.is_available():
-        from torch.distributed._tensor.api import DTensor
-    else:
-        raise ImportError
-except ImportError:
-
-    class DTensor(torch.Tensor):  # noqa: D101
-        ...
-

 class MemoryMappedTensor(torch.Tensor):
     """A Memory-mapped Tensor.
@@ -204,7 +192,7 @@ def from_tensor(
         out.index = None
         out.parent_shape = input.shape
         if copy_data:
-            if isinstance(input, DTensor):
+            if hasattr(input, "full_tensor"):
                 input = input.full_tensor()
             out.copy_(input)
         return out
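The `memmap.py` change above removes a root-level `torch.distributed` import that existed only so `DTensor` could be used in an `isinstance` check (with a stub class as fallback when distributed support was missing). The replacement duck-types instead: anything exposing a `full_tensor()` method is treated as a distributed tensor and materialized before the copy. A minimal sketch of the pattern; the helper name `materialize` is hypothetical, the commit inlines the check in `from_tensor`:

    import torch

    def materialize(input: torch.Tensor) -> torch.Tensor:
        # DTensor exposes .full_tensor() while plain tensors do not, so an
        # attribute check identifies it without importing torch.distributed
        # at module load time.
        if hasattr(input, "full_tensor"):
            input = input.full_tensor()
        return input

The trade-off is that any object with a `full_tensor` attribute takes this branch, which is acceptable here because the code path only receives tensor-like inputs.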
tensordict/persistent.py (18 additions, 1 deletion)
@@ -45,6 +45,9 @@
     NestedKey,
     NUMPY_TO_TORCH_DTYPE_DICT,
 )
 from torch import multiprocessing as mp

+_has_h5 = importlib.util.find_spec("h5py", None) is not None
+
+
 class _Visitor:
@@ -142,7 +145,9 @@ def __init__(
         self._locked_tensordicts = []
         self._lock_id = set()
         if not _has_h5:
-            raise ModuleNotFoundError("Could not load h5py.") from H5_ERR
+            raise ModuleNotFoundError("Could not load h5py.")
+        import h5py
+
         super().__init__()
         self.filename = filename
         self.mode = mode
@@ -204,6 +209,8 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs)
             A :class:`PersitentTensorDict` instance linked to the newly created file.
         """
+        import h5py
+
         file = h5py.File(filename, "w", locking=cls.LOCKING)
         _has_batch_size = True
         if batch_size is None:
@@ -251,6 +258,8 @@ def _get_array(self, key, default=NO_DEFAULT):
             raise KeyError(f"key {key} not found in PersistentTensorDict {self}")

     def _process_array(self, key, array):
+        import h5py
+
         if isinstance(array, (h5py.Dataset,)):
             if self.device is not None:
                 device = self.device
@@ -285,6 +294,8 @@ def get(self, key, default=NO_DEFAULT):
     def get_at(
         self, key: NestedKey, idx: IndexType, default: CompatibleType = NO_DEFAULT
     ) -> CompatibleType:
+        import h5py
+
         array = self._get_array(key, default)
         if isinstance(array, (h5py.Dataset,)):
             if self.device is not None:
@@ -330,6 +341,8 @@ def _get_metadata(self, key):
         This method avoids creating a tensor from scratch, and just reads the metadata of the array.
         """
+        import h5py
+
         array = self._get_array(key)
         if (
             isinstance(array, (h5py.Dataset,))
@@ -971,6 +984,8 @@ def _set_metadata(self, orig_metadata_container: PersistentTensorDict):
             self._nested_tensordicts[key]._set_metadata(td)

     def _clone(self, recurse: bool = True, newfile=None) -> PersistentTensorDict:
+        import h5py
+
         if recurse:
             # this should clone the h5 to a new location indicated by newfile
             if newfile is None:
@@ -1029,6 +1044,8 @@ def __getstate__(self):
         return state

     def __setstate__(self, state):
+        import h5py
+
         state["file"] = h5py.File(
             state["filename"], mode=state["mode"], locking=self.LOCKING
         )
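The `persistent.py` change applies the same idea in two steps. Availability is probed with `importlib.util.find_spec`, which locates the package without executing it, and the real `import h5py` moves into each method that handles h5py objects: the first call pays the import cost, later calls hit the `sys.modules` cache. Dropping the eager import also removes the saved `H5_ERR` exception, hence the shortened `ModuleNotFoundError` in `__init__`. A self-contained sketch of the pattern, with a hypothetical `open_h5` helper standing in for the methods above:

    import importlib.util

    # find_spec confirms the package is installed without importing it,
    # so module load stays cheap even when h5py is slow to import or absent.
    _has_h5 = importlib.util.find_spec("h5py", None) is not None

    def open_h5(filename, mode="r"):
        if not _has_h5:
            raise ModuleNotFoundError("Could not load h5py.")
        import h5py  # first call imports; subsequent calls hit sys.modules

        return h5py.File(filename, mode)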
