Skip to content

Commit

Permalink
[Feature] Support classmethods in tensorclass (#448)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 21, 2023
1 parent 1fc646e commit c3640b7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 15 deletions.
54 changes: 39 additions & 15 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __torch_function__(

cls.__init__ = _init_wrapper(cls.__init__)
cls._from_tensordict = classmethod(_from_tensordict_wrapper(expected_keys))
cls.from_tensordict = cls._from_tensordict
cls.__torch_function__ = classmethod(__torch_function__)
cls.__getstate__ = _getstate
cls.__setstate__ = _setstate
Expand All @@ -180,6 +181,13 @@ def __torch_function__(
cls.state_dict = _state_dict
cls.load_state_dict = _load_state_dict

for attr in TensorDict.__dict__.keys():
func = getattr(TensorDict, attr)
if (
inspect.ismethod(func) and func.__self__ is TensorDict
): # detects classmethods
setattr(cls, attr, _wrap_classmethod(cls, func))

cls.to_tensordict = _to_tensordict
cls.device = property(_device, _device_setter)
cls.batch_size = property(_batch_size, _batch_size_setter)
Expand Down Expand Up @@ -406,21 +414,7 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417
return wrapper


def _getattr(self, attr: str) -> Any:
"""Retrieve the value of an object's attribute, or a method output if attr is callable.
Args:
attr: name of the attribute to retrieve or function to compute
Returns:
value of the attribute, or a method output applied on the instance
"""
res = getattr(self._tensordict, attr)
if not callable(res):
return res
func = res

def _wrap_method(self, attr, func):
@functools.wraps(func)
def wrapped_func(*args, **kwargs):
args = tuple(_arg_to_tensordict(arg) for arg in args)
Expand All @@ -443,6 +437,36 @@ def wrapped_func(*args, **kwargs):
return wrapped_func


def _wrap_classmethod(cls, func):
@functools.wraps(func)
def wrapped_func(*args, **kwargs):
res = func.__get__(cls)(*args, **kwargs)
# res = func(*args, **kwargs)
if isinstance(res, TensorDictBase):
# create a new tensorclass from res and copy the metadata from self
return cls._from_tensordict(res)
return res

return wrapped_func


def _getattr(self, attr: str) -> Any:
"""Retrieve the value of an object's attribute, or a method output if attr is callable.
Args:
attr: name of the attribute to retrieve or function to compute
Returns:
value of the attribute, or a method output applied on the instance
"""
res = getattr(self._tensordict, attr)
if not callable(res):
return res
func = res
return _wrap_method(self, attr, func)


def _getitem(self, item: NestedKey) -> Any:
"""Retrieve the class object at the given index. Indexing will happen for nested tensors as well.
Expand Down
38 changes: 38 additions & 0 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,44 @@ class MyClass:
assert cmemmap.z == "foo"


def test_from_memmap(tmpdir):
td = TensorDict(
{
("a", "b", "c"): 1,
("a", "d"): 2,
},
[],
).expand(10)
td.memmap_(tmpdir)

@tensorclass
class MyClass:
a: TensorDictBase

tc = MyClass.load_memmap(tmpdir)
assert isinstance(tc.a, TensorDict)
assert tc.batch_size == torch.Size([10])


def test_from_dict():
td = TensorDict(
{
("a", "b", "c"): 1,
("a", "d"): 2,
},
[],
).expand(10)
d = td.to_dict()

@tensorclass
class MyClass:
a: TensorDictBase

tc = MyClass.from_dict(d)
assert isinstance(tc.a, TensorDict)
assert tc.batch_size == torch.Size([10])


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 comments on commit c3640b7

Please sign in to comment.