[Feature] from_modules method for MOE / ensemble learning #677

Merged (5 commits) on Feb 15, 2024
136 changes: 135 additions & 1 deletion tensordict/base.py
@@ -52,6 +52,7 @@
_td_fields,
_unravel_key_to_tuple,
as_decorator,
Buffer,
cache,
convert_ellipsis_to_idx,
DeviceType,
@@ -64,6 +65,7 @@
lock_blocked,
NestedKey,
prod,
set_lazy_legacy,
TensorDictFuture,
unravel_key,
unravel_key_list,
@@ -376,6 +378,7 @@ def from_module(
"""Copies the params and buffers of a module in a tensordict.

Args:
module (nn.Module): the module to get the parameters from.
as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams`
instance will be returned which can be used to store parameters
within a :class:`torch.nn.Module`. Defaults to ``False``.
@@ -385,7 +388,7 @@ def from_module(
module will be used and unflattened into a TensorDict with
the tree structure of the model. Defaults to ``False``.
.. note::
This is particularily useful when state-dict hooks have to be
This is particularly useful when state-dict hooks have to be
used.

Examples:
@@ -405,6 +408,137 @@
"""
...

@classmethod
def from_modules(
cls,
*modules,
as_module: bool = False,
lock: bool = True,
use_state_dict: bool = False,
lazy_stack: bool = False,
):
"""Retrieves the parameters of several modules for ensebmle learning/feature of expects applications through vmap.

Args:
modules (sequence of nn.Module): the modules to get the parameters from.
If the modules differ in their structure, a lazy stack is needed
(see the ``lazy_stack`` argument below).

Keyword Args:
as_module (bool, optional): if ``True``, a :class:`~tensordict.nn.TensorDictParams`
instance will be returned which can be used to store parameters
within a :class:`torch.nn.Module`. Defaults to ``False``.
lock (bool, optional): if ``True``, the resulting tensordict will be locked.
Defaults to ``True``.
use_state_dict (bool, optional): if ``True``, the state-dict from the
module will be used and unflattened into a TensorDict with
the tree structure of the model. Defaults to ``False``.
.. note::
This is particularly useful when state-dict hooks have to be
used.
lazy_stack (bool, optional): whether parameters should be densely or
lazily stacked. Defaults to ``False`` (dense stack).

.. note:: ``lazy_stack`` and ``as_module`` are exclusive features.

.. warning::
There is a crucial difference between lazy and non-lazy outputs
in that the non-lazy output will reinstantiate the parameters with the
desired batch-size, while ``lazy_stack`` will just represent
the parameters as lazily stacked. This means that while the
original parameters can safely be passed to an optimizer
when ``lazy_stack=True``, the new parameters need to be passed
when it is set to ``False``.

.. warning::
While it can be tempting to use a lazy stack to keep the
original parameter references, remember that a lazy stack
performs a stack each time :meth:`~.get` is called. This
requires memory (N times the size of the parameters, more if a
graph is built) and time to be computed.
It also means that the optimizer(s) will contain more
parameters, and operations like :meth:`~torch.optim.Optimizer.step`
or :meth:`~torch.optim.Optimizer.zero_grad` will take longer
to be executed. In general, ``lazy_stack`` should be reserved
for very few use cases.

Examples:
>>> import torch
>>> from torch import nn
>>> from tensordict import TensorDict
>>> torch.manual_seed(0)
>>> empty_module = nn.Linear(3, 4, device="meta")
>>> n_models = 2
>>> modules = [nn.Linear(3, 4) for _ in range(n_models)]
>>> params = TensorDict.from_modules(*modules)
>>> print(params)
TensorDict(
fields={
bias: Parameter(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Parameter(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([2]),
device=None,
is_shared=False)
>>> # example of batch execution
>>> def exec_module(params, x):
... with params.to_module(empty_module):
... return empty_module(x)
>>> x = torch.randn(3)
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> # since lazy_stack = False, backprop leaves the original params untouched
>>> y.sum().backward()
>>> assert params["weight"].grad.norm() > 0
>>> assert modules[0].weight.grad is None

With ``lazy_stack=True``, things are slightly different:

>>> params = TensorDict.from_modules(*modules, lazy_stack=True)
>>> print(params)
LazyStackedTensorDict(
fields={
bias: Tensor(shape=torch.Size([2, 4]), device=cpu, dtype=torch.float32, is_shared=False),
weight: Tensor(shape=torch.Size([2, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
exclusive_fields={
},
batch_size=torch.Size([2]),
device=None,
is_shared=False,
stack_dim=0)
>>> # example of batch execution
>>> y = torch.vmap(exec_module, (0, None))(params, x)
>>> assert y.shape == (n_models, 4)
>>> y.sum().backward()
>>> assert modules[0].weight.grad is not None


"""
param_list = [
cls.from_module(module, use_state_dict=use_state_dict) for module in modules
]
if lazy_stack:
from tensordict._lazy import LazyStackedTensorDict

params = LazyStackedTensorDict.lazy_stack(param_list)
else:
with set_lazy_legacy(False), torch.no_grad():
params = torch.stack(param_list)

# Make sure params are params, buffers are buffers
def make_param(param, orig_param):
if isinstance(orig_param, nn.Parameter):
return nn.Parameter(param.detach(), orig_param.requires_grad)
return Buffer(param)

params = params.apply(make_param, param_list[0])

if as_module:
from tensordict.nn import TensorDictParams

params = TensorDictParams(params, no_convert=True)
if lock:
params.lock_()
return params

@as_decorator()
def to_module(
self,
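To make the optimizer caveat from the new docstring concrete, here is a minimal training sketch (not part of the diff) that mirrors the docstring example. With the default dense stack (``lazy_stack=False``) the stacked parameters are freshly created leaves, so they, and not the original modules' parameters, must be handed to the optimizer. The module shapes, learning rate, and variable names below are illustrative only.

```python
import torch
from torch import nn
from tensordict import TensorDict

torch.manual_seed(0)
n_models = 2
modules = [nn.Linear(3, 4) for _ in range(n_models)]
# "meta"-device copy used purely as a stateless template for vmap
empty_module = nn.Linear(3, 4, device="meta")

# Dense stack (default): the stacked parameters are new nn.Parameter leaves.
params = TensorDict.from_modules(*modules)
# Pass the *stacked* parameters to the optimizer, not modules[i].parameters().
optim = torch.optim.SGD(list(params.values(True, True)), lr=0.1)

def exec_module(params, x):
    # Temporarily load one stacked slice into the template module and run it.
    with params.to_module(empty_module):
        return empty_module(x)

x = torch.randn(3)
y = torch.vmap(exec_module, (0, None))(params, x)  # shape: (n_models, 4)
y.sum().backward()
optim.step()        # updates the stacked copies; modules[i] stay untouched
optim.zero_grad()
```

With ``lazy_stack=True`` the situation is reversed: the lazy stack keeps referencing the original parameters, so ``modules[i].parameters()`` can stay in the optimizer, at the memory and runtime cost described in the second warning.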
37 changes: 37 additions & 0 deletions test/test_tensordict.py
@@ -714,6 +714,43 @@ def remover(module, *args, **kwargs):
sd = net.state_dict()
assert_allclose_td(params_sd.flatten_keys("."), TensorDict(sd, []))

@pytest.mark.parametrize("as_module", [False, True])
@pytest.mark.parametrize("lazy_stack", [False, True])
def test_from_modules(self, as_module, lazy_stack):
empty_module = nn.Linear(3, 4, device="meta")
modules = [nn.Linear(3, 4) for _ in range(3)]
if as_module and lazy_stack:
with pytest.raises(RuntimeError, match="within a TensorDictParams"):
params = TensorDict.from_modules(
*modules, as_module=as_module, lazy_stack=lazy_stack
)
return

params = TensorDict.from_modules(
*modules, as_module=as_module, lazy_stack=lazy_stack
)

def vmap_module(params, x):
with params.to_module(empty_module):
return empty_module(x)

x = torch.zeros(3)
y = torch.vmap(vmap_module, (0, None))(params, x)
y.sum().backward()
if lazy_stack:
leaves = []

def get_leaf(leaf):
leaves.append(leaf)

params.apply(get_leaf)
assert all(param.grad is not None for param in leaves)
assert all(param.grad is None for param in params.values(True, True))
else:
for p in modules[0].parameters():
assert p.grad is None
assert all(param.grad is not None for param in params.values(True, True))

@pytest.mark.parametrize(
"idx",
[