[BugFix] Faster to_module (#670)
vmoens authored Feb 7, 2024
1 parent ca92d20 commit 517300a
Showing 6 changed files with 118 additions and 68 deletions.
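
Note: this commit replaces the per-subclass to_module overrides with a single concrete to_module in base.py that delegates to a new private _to_module hook, and makes the recursion accumulate plain dicts for speed. The public API is unchanged; below is a minimal usage sketch assuming the tensordict API exercised by the docstring and the new test in this diff (the module, sizes, and variable names are illustrative, not part of the commit).

import torch
from torch import nn
from tensordict import TensorDict

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 4))
params = TensorDict.from_module(module, as_module=True)

# Bind an alternative parameter set temporarily; the original parameters are
# swapped back in when the context manager exits.
zeroed = params.apply(lambda p: p.data * 0)
with zeroed.to_module(module):
    assert (module[0].weight == 0).all()
assert not (module[0].weight == 0).all()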
4 changes: 2 additions & 2 deletions tensordict/_lazy.py
@@ -2438,7 +2438,7 @@ def _unsqueeze(self, dim):
masked_select = TensorDict.masked_select
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
_to_module = TensorDict._to_module


class _CustomOpTensorDict(TensorDictBase):
@@ -2959,7 +2959,7 @@ def _unsqueeze(self, dim):
masked_select = TensorDict.masked_select
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
_to_module = TensorDict._to_module
_apply_nest = TensorDict._apply_nest
_remove_batch_dim = TensorDict._remove_batch_dim
all = TensorDict.all
107 changes: 51 additions & 56 deletions tensordict/_td.py
@@ -21,6 +21,7 @@
import numpy as np
import torch
from functorch import dim as ftdim

from tensordict.base import (
_ACCEPTED_CLASSES,
_default_is_leaf,
@@ -250,6 +251,10 @@ def _from_module(
use_state_dict: bool = False,
prefix="",
):
from tensordict.nn import TensorDictParams

if isinstance(module, TensorDictParams):
return module
destination = {}
if use_state_dict:
keep_vars = False
@@ -280,7 +285,7 @@ def _from_module(
if submodule is not None:
subtd = cls._from_module(
module=submodule,
as_module=as_module,
as_module=False,
use_state_dict=use_state_dict,
prefix=prefix + name + ".",
)
@@ -293,21 +298,21 @@ def _from_module(
return destination

def is_empty(self):
from tensordict import NonTensorData

for _, item in self._tensordict.items():
for item in self._tensordict.values():
# we need to check if item is empty
if (
_is_tensor_collection(type(item))
and not isinstance(item, NonTensorData)
and item.is_empty()
):
continue
return False
if _is_tensor_collection(type(item)):
if not item.is_empty():
return False
from tensordict.tensorclass import NonTensorData

if isinstance(item, NonTensorData):
return False
else:
return False
return True
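
The reworked is_empty above keeps the previous semantics (nested empty sub-tensordicts do not count as content, while NonTensorData leaves do) but iterates over values() directly and defers the NonTensorData import. A small illustrative sketch of that behaviour, with made-up values:

import torch
from tensordict import TensorDict

td = TensorDict({"nested": TensorDict({}, [])}, [])
assert td.is_empty()       # an empty sub-tensordict does not count as content

td["x"] = torch.zeros(())
assert not td.is_empty()   # a tensor leaf does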

@as_decorator()
def to_module(
def _to_module(
self,
module,
*,
@@ -317,28 +322,24 @@ def to_module(
memo=None,
use_state_dict: bool = False,
):

if not use_state_dict and isinstance(module, TensorDictBase):
if return_swap:
swap = module.copy()
module.update(self)
return swap
else:
module.update(self)
return

# we use __dict__ directly to avoid the getattr/setattr overhead whenever we can
__dict__ = module.__dict__

swap = None
has_set_device = False
if memo is None:
hooks = getattr(
torch.nn.modules.module, "_global_parameter_registration_hooks", {}
)
memo = {"hooks": tuple(hooks.values())}
else:
hooks = memo["hooks"]
hooks = memo["hooks"]
if return_swap:
# this could break if the device and batch-size are not congruent.
# For batch-size it is a minor issue (unlikely that a td with batch-size
# is passed with to_module) but for the device it could be a problem.
if swap_dest is None:
swap = TensorDict({}, batch_size=torch.Size(()), _run_checks=False)
else:
swap = swap_dest
memo[id(module)] = swap
_swap = {}
memo[id(module)] = _swap

if use_state_dict:
if inplace is not None:
raise RuntimeError(
@@ -370,7 +371,7 @@ def convert_type(x, y):
return Buffer(x)
return x

input = state_dict.unflatten_keys(".").apply(convert_type, self)
input = state_dict.unflatten_keys(".")._fast_apply(convert_type, self)
else:
input = self
inplace = bool(inplace)
@@ -397,44 +398,38 @@ def convert_type(x, y):
else:
if value.is_empty():
# if there is at least one key, we must populate the module.
# Otherwise we just go to the next key
# Otherwise, we just go to the next key
continue
if swap_dest is not None:
local_dest = swap_dest._get_str(key, default=NO_DEFAULT)
else:
local_dest = None
child = __dict__["_modules"][key]
if id(child) in memo:
local_out = memo[id(child)]
else:
local_out = value.to_module(
local_out = memo.get(id(child), NO_DEFAULT)
if local_out is NO_DEFAULT:
local_out = value._to_module(
child,
inplace=inplace,
return_swap=return_swap,
swap_dest=local_dest,
swap_dest={}, # we'll be calling update later
memo=memo,
use_state_dict=use_state_dict,
)
# we don't want to do this op more than once
if return_swap and (
not has_set_device
and swap.device is not None
and local_out.device is not None
and local_out.device != swap.device
):
has_set_device = True
# map out to the local_out device
swap = swap.to(device=local_out.device)

if return_swap:
_swap[key] = local_out
if return_swap:
if isinstance(swap, TensorDict):
# this is very ad-hoc but faster than calling _set_str every time
swap._tensordict.update(_swap)
if isinstance(swap_dest, dict):
return _swap
elif swap_dest is not None:

def _quick_set(swap_dict, swap_td):
for key, val in swap_dict.items():
if isinstance(val, dict):
_quick_set(val, swap_td._get_str(key, default=NO_DEFAULT))
else:
swap_td._set_str(key, val, inplace=False, validated=True)

_quick_set(_swap, swap_dest)
return swap_dest
else:
swap.update(_swap)
return swap
return TensorDict(_swap, batch_size=[], _run_checks=False)

def __ne__(self, other: object) -> T | bool:
if _is_tensorclass(other):
@@ -2829,7 +2824,7 @@ def _create_nested_str(self, key):
memmap_like = TensorDict.memmap_like
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind

def _view(self, *args, **kwargs):
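The main speed-up in _td.py comes from _to_module accumulating swapped values in plain dicts during the recursion and only building a TensorDict (or _quick_set-ing into a provided destination) once at the top level, instead of creating and updating a TensorDict at every nesting level. A simplified sketch of that dict-first idea, with illustrative names rather than the library internals:

import torch
from torch import nn
from tensordict import TensorDict

def collect_params(module: nn.Module) -> dict:
    # Gather parameters into nested plain dicts first (cheap per level) ...
    out = dict(module.named_parameters(recurse=False))
    for name, child in module.named_children():
        out[name] = collect_params(child)
    return out

# ... and pay the TensorDict construction cost only once, at the very end.
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
swap = TensorDict(collect_params(net), batch_size=[])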
33 changes: 27 additions & 6 deletions tensordict/base.py
@@ -405,7 +405,6 @@ def from_module(
"""
...

@abc.abstractmethod
@as_decorator()
def to_module(
self,
Expand All @@ -414,8 +413,8 @@ def to_module(
inplace: bool | None = None,
return_swap: bool = True,
swap_dest=None,
memo=None,
use_state_dict: bool = False,
memo=None, # deprecated
):
"""Writes the content of a TensorDictBase instance onto a given nn.Module attributes, recursively.
@@ -429,10 +428,6 @@
will be returned. Defaults to ``False``.
swap_dest (TensorDictBase, optional): if ``return_swap`` is ``True``,
the tensordict where the swap should be written.
memo (dict, optional): when the same module is present multiple times
in the input module, a memo is used to avoid fetching the params
that have just been set. This argument should be ignored during
regular calls to `to_module`.
use_state_dict (bool, optional): if ``True``, state-dict API will be
used to load the parameters (including the state-dict hooks).
Defaults to ``False``.
@@ -447,6 +442,32 @@
>>> params.to_module(module)
>>> assert (module.layers[0].linear1.weight == 0).all()
"""
if memo is not None:
raise RuntimeError("memo cannot be passed to the public to_module anymore.")
hooks = getattr(
torch.nn.modules.module, "_global_parameter_registration_hooks", {}
)
memo = {"hooks": tuple(hooks.values())}
return self._to_module(
module=module,
inplace=inplace,
return_swap=return_swap,
swap_dest=swap_dest,
memo=memo,
use_state_dict=use_state_dict,
)

@abc.abstractmethod
def _to_module(
self,
module,
*,
inplace: bool | None = None,
return_swap: bool = True,
swap_dest=None,
memo=None,
use_state_dict: bool = False,
):
...

# Shape functionality
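In base.py, to_module is now concrete: it prepares the parameter-registration hooks and the memo once, then dispatches to the abstract _to_module implemented by each subclass, with the memo keyed by id(module) so shared submodules are only processed once. The sketch below is a simplified illustration of that wrapper/worker split, not the library code; the helper names and the direct _parameters writes are assumptions made for brevity.

import torch
from torch import nn

def public_to_module(params: dict, module: nn.Module) -> dict:
    # Public entry point: build the shared state once per call ...
    hooks = getattr(
        torch.nn.modules.module, "_global_parameter_registration_hooks", {}
    )
    memo = {"hooks": tuple(hooks.values())}
    return _to_module_worker(params, module, memo=memo)

def _to_module_worker(params: dict, module: nn.Module, *, memo: dict) -> dict:
    # ... while the private worker threads it through the recursion, recording
    # every visited module so repeated children are handled a single time.
    swap = {}
    memo[id(module)] = swap
    for name, value in params.items():
        if isinstance(value, dict):
            child = module._modules[name]
            cached = memo.get(id(child))
            swap[name] = (
                cached if cached is not None
                else _to_module_worker(value, child, memo=memo)
            )
        else:
            # values are expected to be nn.Parameter instances in this sketch
            swap[name] = module._parameters.get(name)
            module._parameters[name] = value
    return swap

linear = nn.Linear(2, 2)
old = public_to_module(
    {"weight": nn.Parameter(torch.zeros(2, 2)), "bias": nn.Parameter(torch.zeros(2))},
    linear,
)
assert (linear.weight == 0).all() and isinstance(old["weight"], nn.Parameter)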
4 changes: 1 addition & 3 deletions tensordict/nn/params.py
@@ -31,7 +31,6 @@
)
from tensordict.utils import (
_LOCK_ERROR,
as_decorator,
Buffer,
erase_cache,
IndexType,
@@ -909,8 +908,7 @@ def split(self, split_size: int | list[int], dim: int = 0) -> list[TensorDictBase]:
...

@_fallback
@as_decorator()
def to_module(
def _to_module(
self,
module,
*,
2 changes: 1 addition & 1 deletion tensordict/persistent.py
@@ -1096,7 +1096,7 @@ def _unsqueeze(self, dim):
masked_select = TensorDict.masked_select
reshape = TensorDict.reshape
split = TensorDict.split
to_module = TensorDict.to_module
_to_module = TensorDict._to_module
_unbind = TensorDict._unbind
_get_names_idx = TensorDict._get_names_idx

36 changes: 36 additions & 0 deletions test/test_nn.py
@@ -2972,6 +2972,42 @@ def test_tdparams_clone_tying(self):
td_clone = td.clone()
assert td_clone["c"] is td_clone["a", "b", "c"]

def test_func_on_tdparams(self):
# tdparams isn't represented in a nested way, so we must check that calling to_module on it works ok
net = nn.Sequential(
nn.Linear(2, 2),
nn.Sequential(
nn.Linear(2, 2),
nn.Dropout(),
nn.BatchNorm1d(2),
nn.Sequential(
nn.Tanh(),
nn.Linear(2, 2),
),
),
)

params = TensorDict.from_module(net, as_module=True)

params0 = params.apply(lambda x: x.data * 0)
assert (params0 == 0).all()
with params0.to_module(params):
assert (params == 0).all()
assert not (params == 0).all()

# Now with a module around it
class MyModule(nn.Module):
pass

m = MyModule()
m.params = params
params_m = TensorDict.from_module(m, as_module=True)
params_m0 = params_m.apply(lambda x: x.data * 0)
assert (params_m0 == 0).all()
with params_m0.to_module(m):
assert (params_m == 0).all()
assert not (params_m == 0).all()

def test_inplace_ops(self):
td = TensorDict(
{
