[Feature Request] Q Ensembles #1344

Open
1 task done
smorad opened this issue Jul 1, 2023 · 5 comments

@smorad
Contributor

smorad commented Jul 1, 2023

Motivation

Twin-Q/ensemble Q functions are used in many RL algorithms to mitigate Q-value overestimation. My understanding is that TorchRL only handles ensembles inside the loss functions. That is fine for actor-critic methods, since the critics are only needed to compute the actor loss. But for critic-only methods (e.g. DQN), we need the Q ensemble at sample-collection time. First-class ensemble support would also simplify the loss functions for DDPG/SAC/REDQ/etc.

Solution

I would like to add ensemble Q-function support to TorchRL, but I'm not sure of the best way to do this. I was thinking of creating a TensorDictModuleEnsemble in tensordict_module.py that could be used for more than just Q functions. The issue is that we essentially need two forward functions: one at sample time that computes some reduce operation (e.g. min) over the ensemble outputs, and one at training time that does not reduce, but instead does something like

self.ensemble(tensordict.expand(ensemble_size, *tensordict.shape))

so we can compute the loss for all Q functions. I'm not sure if there is a good way to tell a TensorDictModule whether it is in "sampling" or "training" mode.
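For concreteness, here's a rough sketch of what I have in mind (the TensorDictModuleEnsemble name and the reduce_fn switch are just placeholders, not an existing API):

import torch
from torch import nn
from tensordict.nn import TensorDictModuleBase


class TensorDictModuleEnsemble(TensorDictModuleBase):
    """Placeholder sketch, not an existing API: N copies of a module, optionally reduced."""

    def __init__(self, modules, reduce_fn=None):
        super().__init__()
        self.in_keys = modules[0].in_keys
        self.out_keys = modules[0].out_keys
        self.members = nn.ModuleList(modules)
        self.reduce_fn = reduce_fn  # None -> "training" mode: keep one output per member

    def forward(self, td):
        out_key = self.out_keys[0]
        # run every member on (a shallow copy of) the same input tensordict
        values = torch.stack([m(td.clone(False)).get(out_key) for m in self.members], dim=0)
        if self.reduce_fn is not None:
            # "sampling" mode, e.g. reduce_fn=lambda v: v.min(dim=0).values
            values = self.reduce_fn(values)
        td.set(out_key, values)
        return td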

I think it also makes sense to provide an option to keep separate datasets for each model/Q function, e.g.

tensordict.set("ensemble_idx", tensordict.get("ensemble_idx", torch.randint(ensemble_size, 1)))
for i in range(ensemble_size):
  sub_td = tensordict[tensordict["ensemble_idx"] == i]
  self.models[i](sub_td)
...

I'd like to avoid the for-loop here if possible, but I'm not sure how.
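One way around the loop (just a sketch of the idea, with made-up tensor names) could be to run every member on the full batch and mask the per-member loss with ensemble_idx:

import torch
import torch.nn.functional as F

ensemble_size, batch_size = 4, 32

# hypothetical tensors: q_values[k, b] is member k's estimate for sample b,
# target[b] is a shared TD target, ensemble_idx[b] assigns sample b to one member
q_values = torch.randn(ensemble_size, batch_size)
target = torch.randn(batch_size)
ensemble_idx = torch.randint(ensemble_size, (batch_size,))

# one-hot mask (ensemble_size x batch_size): each sample only trains "its" member
mask = F.one_hot(ensemble_idx, ensemble_size).T.float()
loss = (((q_values - target) ** 2) * mask).sum() / mask.sum()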

Additional context

Related to #876

Checklist

  • I have checked that there is no similar issue in the repo (required)
@smorad smorad added the enhancement New feature or request label Jul 1, 2023
@vmoens
Contributor

vmoens commented Jul 3, 2023

Thanks for raising this!
Seems like an interesting problem to tackle.

Regarding the basic implementation, I think we should keep a contiguous stack of params and use vmap over the Q-value network; that should be faster than looping over the networks/params.
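To illustrate the stacked-params + vmap idea with plain torch.func (outside of tensordict; the toy Q-network below is purely for illustration):

import copy

import torch
from torch import nn
from torch.func import functional_call, stack_module_state

# toy Q-network purely for illustration (obs dim 4, two discrete actions)
base = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))
models = [copy.deepcopy(base) for _ in range(3)]

# contiguous stacks of params/buffers, with one leading dim per ensemble member
params, buffers = stack_module_state(models)
template = copy.deepcopy(base).to("meta")  # stateless template for functional_call

def call_one(p, b, obs):
    return functional_call(template, (p, b), (obs,))

obs = torch.randn(8, 4)
# vmap over the member dimension while broadcasting the same observations
q_values = torch.vmap(call_one, (0, 0, None))(params, buffers, obs)  # shape [3, 8, 2]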

Next, you're raising the question of whether and how to select the min value. We do that in a very custom fashion in the losses, but I agree that we should find a more generic way of doing it. If we get a table of state-action values like

[[v_00, v_01],
 [v_10, v_11]]

where rows index actions 0 and 1 and columns index networks 0 and 1, how do you select an action using the min operator (knowing that you want the action with the maximum value)?
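For that table, the selection I have in mind (just sketching the arithmetic) is a min over the network dimension followed by an argmax over actions:

import torch

# rows = actions, columns = ensemble members (the v_ij table above)
values = torch.tensor([[1.0, 0.4],
                       [0.7, 0.6]])

# conservative value per action: min over the ensemble dimension ...
conservative = values.min(dim=-1).values   # tensor([0.4000, 0.6000])
# ... then act greedily w.r.t. the conservative estimate
action = conservative.argmax(dim=-1)       # tensor(1)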
IMO we could build the backbone that produces the table and append an action-selection strategy, like we do for QValue networks (I'm trying to move away from wrappers, as wrapping quickly puts things in super nested structures where you don't really know where your original module lives). So we'd have something like

policy = TensorDictSequential(
  EnsembleStateActionValue(...), # uses vmap, writes a table of values in the output tensordict
  EnsembleQValueActor(...), # selects the action given some heuristic
)

In the loss function, we can either pass the entire policy, which will work since TensorDictSequential keeps track of intermediate values in the output tensordict (but it's risky: one could overwrite the actions in the tensordict), or just pass the EnsembleStateActionValue and let the loss deal with it.

Happy to sketch a solution in a notebook if that helps!

@smorad
Contributor Author

smorad commented Jul 4, 2023

Do you think it makes sense to break this into three modules, so we can utilize the default QValueActor?

policy = TensorDictSequential(
  Ensemble(in_keys=['observation'], out_keys=['ensemble_state_action_value']),
  Reduce(
    in_keys=['ensemble_state_action_value'],
    out_keys=['state_action_value'],
    reduce_fn=lambda x, dim: x.min(dim=dim).values,
  ),
  QValueActor(env.action_spec),
)

This keeps ensemble/reduce more general as they could be useful outside of Q functions.
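A rough sketch of what that Reduce module could look like (assuming the ensemble dimension is the leading dimension of the value tensor; none of this exists in TorchRL yet):

from tensordict.nn import TensorDictModuleBase


class Reduce(TensorDictModuleBase):
    """Hypothetical module: collapses the ensemble dimension of a value tensor."""

    def __init__(self, in_keys, out_keys, reduce_fn):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = out_keys
        self.reduce_fn = reduce_fn

    def forward(self, td):
        # reduce over the leading (ensemble) dimension,
        # e.g. reduce_fn=lambda x, dim: x.min(dim=dim).values
        td.set(self.out_keys[0], self.reduce_fn(td.get(self.in_keys[0]), 0))
        return td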

@vmoens
Contributor

vmoens commented Jul 4, 2023

I think that could work

Here's how I would go about the vmap module

from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase, TensorDictModule
from torch import nn
import torch

net = nn.Sequential(nn.Linear(10, 128), nn.Tanh(), nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128))

module = TensorDictModule(net, in_keys=["in"], out_keys=["out"])


class VmapParamModule(TensorDictModuleBase):
    def __init__(self, module, num_copies):
        super().__init__()
        # extract the module's parameters as a TensorDict and make num_copies stacked copies
        params = TensorDict.from_module(module)
        params = params.expand(num_copies).to_tensordict()
        self.in_keys = module.in_keys
        self.out_keys = module.out_keys

        self.params_td = params
        # register the copied parameters so optimizers and .to(device) can see them
        self.params = nn.ParameterList(list(params.values(True, True)))
        self.module = module

    def forward(self, td):
        # vmap over the parameter dimension: the same td is broadcast to every copy
        return torch.vmap(self.module, (None, 0))(td, self.params_td)

vmap_module = VmapParamModule(module, 2)
td = TensorDict({"in": torch.randn(10)}, [])
vmap_module(td)

This gives you a td:

TensorDict(
    fields={
        in: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

@smorad
Contributor Author

smorad commented Jul 5, 2023

Great, thanks. One last thing: I'm struggling to find a nice way to reinitialize the copied parameters. Calling reset_parameters on each module individually would be ideal, but not all torch modules have this; for example, nn.Sequential does not expose reset_parameters.

Furthermore, by disconnecting the modules from the parameters via TensorDict.from_module, we can no longer call reset_parameters on each copy of the parameters. We could pass a parameter_init_function to VmapParamModule, but this will become ugly fast. For example, consider an MLP with a special final-layer init. The linear bias init requires fan_in, which depends on the associated weight. The final layer of the MLP should do something like nn.init.normal_(weight, 0, 1e-4), bias.zero_(). So now we need to work out how to associate a weight with a bias, and we also need to figure out which weight and bias belong to the final layer. I don't even want to think about adding a CNN to the mix.
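For reference, this is roughly what nn.Linear.reset_parameters does, which is why the bias cannot be re-initialized without access to its weight:

import math
import torch
from torch import nn

def reset_linear_(weight: torch.Tensor, bias: torch.Tensor) -> None:
    # roughly what nn.Linear.reset_parameters does: the bias bound is derived
    # from the weight's fan_in, so the two must be re-initialized together
    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    nn.init.uniform_(bias, -bound, bound)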

It's ugly, but perhaps something like the following is the best we can do?

from copy import deepcopy

modules = [deepcopy(module) for _ in range(num_copies)]
for m in modules:
    user_defined_reset_weights_fn_(m)  # user-provided re-initialization
params_td = TensorDict({f"copy_{k}": TensorDict.from_module(modules[k]) for k in range(num_copies)}, batch_size=[])

@smorad
Contributor Author

smorad commented Jul 5, 2023

Another solution would be to add a recursive reset_parameters to TensorDictModuleBase; perhaps this would be useful elsewhere in torchrl. Then we could reinitialize CNN/LSTM/MLP TensorDictModule parameters without having to write different code for each TensorDictModule.

class TensorDictModuleBase:
    ...
    def reset_parameters(self):
        # start the recursion at the children to avoid calling back into this method
        for child in self.children():
            self._reset_parameters(child)

    def _reset_parameters(self, module: nn.Module):
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
        else:
            for child in module.children():
                self._reset_parameters(child)
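As hypothetical usage (reusing the module and num_copies names from the VmapParamModule snippet above), the ensemble could then rebuild its parameter stack from freshly reset copies:

from copy import deepcopy

import torch
from tensordict import TensorDict

# reuses `module` and `num_copies` from the earlier snippet
copies = [deepcopy(module) for _ in range(num_copies)]
for m in copies:
    m.reset_parameters()  # the recursive reset proposed above
params_td = torch.stack([TensorDict.from_module(m) for m in copies], dim=0)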
