[BugFix, Performance] Fewer imports at root (#682)
vmoens authored Feb 19, 2024
1 parent b43baa8 commit 8485755
Showing 2 changed files with 17 additions and 22 deletions.
14 changes: 1 addition & 13 deletions tensordict/memmap.py
@@ -22,20 +22,8 @@
 from tensordict.utils import implement_for
 
-from torch import distributed as dist
-
 from torch.multiprocessing.reductions import ForkingPickler
 
-try:
-    if dist.is_available():
-        from torch.distributed._tensor.api import DTensor
-    else:
-        raise ImportError
-except ImportError:
-
-    class DTensor(torch.Tensor):  # noqa: D101
-        ...
-
 
 class MemoryMappedTensor(torch.Tensor):
     """A Memory-mapped Tensor.
@@ -204,7 +192,7 @@ def from_tensor(
         out.index = None
         out.parent_shape = input.shape
         if copy_data:
-            if isinstance(input, DTensor):
+            if hasattr(input, "full_tensor"):
                 input = input.full_tensor()
             out.copy_(input)
         return out
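The memmap.py change swaps an isinstance() check against DTensor for duck typing on the full_tensor attribute, so neither torch.distributed nor the DTensor fallback class needs to exist at import time. A minimal sketch of the pattern, with an illustrative function name that is not part of the tensordict API:

import torch

def _densify(input: torch.Tensor) -> torch.Tensor:
    # A distributed DTensor exposes full_tensor(); plain tensors do not.
    # Probing the attribute avoids importing torch.distributed just to
    # run an isinstance() check at module import time.
    if hasattr(input, "full_tensor"):
        input = input.full_tensor()
    return input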
25 changes: 16 additions & 9 deletions tensordict/persistent.py
@@ -46,14 +46,7 @@
 )
 from torch import multiprocessing as mp
 
-H5_ERR = None
-try:
-    import h5py
-
-    _has_h5 = True
-except ModuleNotFoundError as err:
-    H5_ERR = err
-    _has_h5 = False
+_has_h5 = importlib.util.find_spec("h5py", None) is not None
 
 
 class _Visitor:
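importlib.util.find_spec locates a module on the search path without executing it, so the availability flag above is computed without paying for the h5py import. A minimal standalone sketch of the check, using only the standard library:

import importlib.util

# find_spec() only searches for the module; it does not run its code,
# so this probe is much cheaper than a guarded "import h5py".
_has_h5 = importlib.util.find_spec("h5py") is not None

One trade-off: since no ImportError is raised here, there is no captured exception to chain from when h5py is missing, which is why the raise in __init__ below drops its "from H5_ERR" clause.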
@@ -151,7 +144,9 @@ def __init__(
         self._locked_tensordicts = []
         self._lock_id = set()
         if not _has_h5:
-            raise ModuleNotFoundError("Could not load h5py.") from H5_ERR
+            raise ModuleNotFoundError("Could not load h5py.")
+        import h5py
+
         super().__init__()
         self.filename = filename
         self.mode = mode
@@ -213,6 +208,8 @@ def from_dict(cls, input_dict, filename, batch_size=None, device=None, **kwargs)
             A :class:`PersitentTensorDict` instance linked to the newly created file.
         """
+        import h5py
+
         file = h5py.File(filename, "w", locking=cls.LOCKING)
         _has_batch_size = True
         if batch_size is None:
@@ -260,6 +257,8 @@ def _get_array(self, key, default=NO_DEFAULT):
             raise KeyError(f"key {key} not found in PersistentTensorDict {self}")
 
     def _process_array(self, key, array):
+        import h5py
+
         if isinstance(array, (h5py.Dataset,)):
             if self.device is not None:
                 device = self.device
@@ -294,6 +293,8 @@ def get(self, key, default=NO_DEFAULT):
     def get_at(
         self, key: NestedKey, idx: IndexType, default: CompatibleType = NO_DEFAULT
     ) -> CompatibleType:
+        import h5py
+
         array = self._get_array(key, default)
         if isinstance(array, (h5py.Dataset,)):
             if self.device is not None:
@@ -339,6 +340,8 @@ def _get_metadata(self, key):
         This method avoids creating a tensor from scratch, and just reads the metadata of the array.
         """
+        import h5py
+
         array = self._get_array(key)
         if (
             isinstance(array, (h5py.Dataset,))
@@ -1048,6 +1051,8 @@ def _set_metadata(self, orig_metadata_container: PersistentTensorDict):
             self._nested_tensordicts[key]._set_metadata(td)
 
     def _clone(self, recurse: bool = True, newfile=None) -> PersistentTensorDict:
+        import h5py
+
         if recurse:
             # this should clone the h5 to a new location indicated by newfile
             if newfile is None:
@@ -1106,6 +1111,8 @@ def __getstate__(self):
         return state
 
     def __setstate__(self, state):
+        import h5py
+
         state["file"] = h5py.File(
             state["filename"], mode=state["mode"], locking=self.LOCKING
         )
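The remaining hunks all apply the same deferred-import pattern: the module-level "import h5py" is gone, and every method that touches h5py (including __setstate__, which runs during unpickling) imports it locally. Because Python caches imported modules in sys.modules, only the first such import does real work; later ones amount to a dictionary lookup. A minimal sketch with a hypothetical class that is not part of tensordict:

class H5Writer:
    """Hypothetical example of method-local (lazy) h5py imports."""

    def write(self, filename, data):
        # Resolved from the sys.modules cache after the first call, so the
        # per-call cost is negligible while "import tensordict" stays fast.
        import h5py

        with h5py.File(filename, "w") as f:
            f.create_dataset("data", data=data)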
