[Feature Request] Q Ensembles #1344
Thanks for raising this. Regarding the basic implementation, I think we should keep a contiguous stack of params and use vmap over the Q-value network; that should be faster than looping over the networks / params.

Next, you're raising the question of how to select (or not) the min value. We do that in a very custom fashion in the losses, but I agree that we should find a more generic way of doing so. If we get a table of state-action values like

[[v_00, v_01],
 [v_10, v_11]]

for row-based actions 0, 1 and column-based networks 0, 1, how do you select an action using something like the following?

policy = TensorDictSequential(
    EnsembleStateActionValue(...),  # uses vmap, writes a table of values in the output tensordict
    EnsembleQValueActor(...),       # selects the action given some heuristic
)

In the loss function, we can either pass the entire policy, which will work since TensorDictSequential keeps track of intermediate values in the output tensordict (but risky: one could overwrite the actions in the tensordict), or just pass the […]

Happy to sketch a solution in a notebook if that helps!
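As a toy illustration of the action-selection question above (not code from the thread): with rows indexing actions and columns indexing networks, one heuristic is to take the min over networks first and then the argmax over actions.

import torch

# toy table: rows are actions 0/1, columns are networks 0/1
values = torch.tensor([[1.0, 2.5],
                       [0.5, 3.0]])
pessimistic_q = values.min(dim=1).values  # min over the network axis -> one value per action
action = pessimistic_q.argmax(dim=0)      # greedy action under the min-Q heuristic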
Do you think it makes sense to break this into three modules, so we can utilize the default QValueActor? For example:

policy = TensorDictSequential(
    Ensemble(in_keys=['observation'], out_keys=['ensemble_state_action_value']),
    Reduce(in_keys=['ensemble_state_action_value'], out_keys=['state_action_value'], reduce_fn=lambda x, dim: x.min(dim=dim)),
    QValueActor(env.action_spec),
)

This keeps ensemble/reduce more general, as they could be useful outside of Q functions.
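A minimal sketch of what such a generic Reduce module could look like (not code from the thread; it assumes the ensemble values are stacked along a leading dimension under a single key):

from tensordict.nn import TensorDictModuleBase

class Reduce(TensorDictModuleBase):
    """Reduce a stack of ensemble values along its leading (ensemble) dimension."""

    def __init__(self, in_keys, out_keys, reduce_fn=lambda x, dim: x.min(dim=dim).values):
        super().__init__()
        self.in_keys = in_keys
        self.out_keys = out_keys
        self.reduce_fn = reduce_fn

    def forward(self, tensordict):
        # assumes dim 0 of the stacked tensor is the ensemble dimension
        values = tensordict.get(self.in_keys[0])
        tensordict.set(self.out_keys[0], self.reduce_fn(values, 0))
        return tensordict

Because reduce_fn is a plain callable, the same module could take a mean for vanilla ensembling or a min for the pessimistic estimate.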
I think that could work. Here's how I would go about the vmap module:

from tensordict import TensorDict
from tensordict.nn import TensorDictModuleBase, TensorDictModule
from torch import nn
import torch

net = nn.Sequential(nn.Linear(10, 128), nn.Tanh(), nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128), nn.Tanh(), nn.Linear(128, 128))
module = TensorDictModule(net, in_keys=["in"], out_keys=["out"])

class VmapParamModule(TensorDictModuleBase):
    def __init__(self, module, num_copies):
        super().__init__()
        # expand the module's parameters along a new leading dim, one slice per copy
        params = TensorDict.from_module(module)
        params = params.expand(num_copies).to_tensordict()
        self.in_keys = module.in_keys
        self.out_keys = module.out_keys
        self.params_td = params
        # keep the expanded parameters registered on the module (e.g. for optimizers)
        self.params = nn.ParameterList(list(params.values(True, True)))
        self.module = module

    def forward(self, td):
        # vmap over the parameter stack only; the input tensordict is shared across copies
        return torch.vmap(self.module, (None, 0))(td, self.params_td)

vmap_module = VmapParamModule(module, 2)
td = TensorDict({"in": torch.randn(10)}, [])
vmap_module(td)

This gives you a td:

TensorDict(
    fields={
        in: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        out: Tensor(shape=torch.Size([2, 128]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
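A small follow-up (not from the thread): the stacked output can feed both the loss, which needs every head, and action selection, which wants a reduced value, in line with the Reduce idea above. Key names are those of the example.

out = vmap_module(td)
per_copy = out["out"]                 # shape [2, 128], one row per parameter copy, used by the loss
reduced = per_copy.min(dim=0).values  # shape [128], pessimistic value for acting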
Great, thanks. One last thing: I'm struggling with a nice way to reinitialize the copied parameters. Calling […] Furthermore, by disconnecting the modules from the parameters via […] It's ugly, but perhaps something like the following is the best we can do?

from copy import deepcopy

modules = [deepcopy(module) for _ in range(num_copies)]
[user_defined_reset_weights_fn_(m) for m in modules]
params_td = TensorDict({f"copy_{k}": TensorDict.from_module(modules[k]) for k in range(num_copies)})
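One possible variant (a sketch, not from the thread): re-initialize each deep copy and then stack the per-copy parameter TensorDicts along a new leading dimension, so the result plugs directly into the vmap module above. The reset function below is a stand-in for whatever re-initialization the user wants.

from copy import deepcopy

import torch
from tensordict import TensorDict
from torch import nn

def reset_weights_(m):
    # stand-in re-init: re-run the default initialization of Linear layers
    if isinstance(m, nn.Linear):
        m.reset_parameters()

num_copies = 2
base = nn.Sequential(nn.Linear(10, 128), nn.Tanh(), nn.Linear(128, 128))
copies = [deepcopy(base) for _ in range(num_copies)]
for c in copies:
    c.apply(reset_weights_)

# stack the per-copy parameters into a single TensorDict with a leading "copy" dimension
params_td = torch.stack([TensorDict.from_module(c) for c in copies], dim=0).to_tensordict()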
Another solution would be to add a recursive reset_parameters to TensorDictModuleBase:

class TensorDictModuleBase:
    ...

    def reset_parameters(self):
        self._reset_parameters(self)

    def _reset_parameters(self, module: Union["TensorDictModuleBase", nn.Module]):
        if isinstance(module, (TensorDictModuleBase, nn.Module)):
            # modules that define their own reset_parameters handle themselves;
            # otherwise recurse into the children (skipping this base implementation)
            if hasattr(module, "reset_parameters") and module is not self:
                module.reset_parameters()
            else:
                for m in module.children():
                    self._reset_parameters(m)
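If something like this were added, re-initializing each deep copy before stacking its parameters would collapse to a single call per copy (hypothetical usage, following the sketch above):

for m in modules:
    m.reset_parameters()  # recursively resets every child module that knows how to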
Motivation
Twin Q/ensemble Q functions are used in many RL algorithms and mitigate Q overestimation. My understanding is that TorchRL only deals with ensembles in the loss functions. This is fine for actor/critic methods since we only use the critics to compute actor loss. But for critic-only methods (e.g. DQN), we need the Q ensemble at sample collection time. Doing so would also simplify the loss functions for DDPG/SAC/REDQ/etc.
Solution
I would like to add ensemble Q function support to TorchRL, but I'm not sure of the best way to do this. I was thinking of creating a TensorDictModuleEnsemble in tensordict_module.py that could be used for more than just Q functions. The issue is that we essentially need two forward functions: one at sample time, which computes some reduce operation like min over the ensemble outputs, and one at training time, which does not reduce but rather does something like […] so we can compute the loss for all Q functions. I'm not sure if there is a good way to tell a TensorDictModule whether it is in "sampling" or "training" mode.

I think it also makes sense to provide an option to keep separate datasets for each model/Q function, e.g. […]. I'd like to avoid the for-loop here if possible, but I'm not sure how.
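A minimal sketch of one way to handle the two modes (an illustration, not a reconstruction of the elided snippets above): reuse the standard nn.Module train()/eval() flag and only reduce in eval mode. The class and key names are assumptions.

from tensordict.nn import TensorDictModuleBase

class EnsembleQValue(TensorDictModuleBase):
    """Exposes all ensemble heads in training mode and a reduced value in eval mode."""

    def __init__(self, in_key="ensemble_state_action_value", out_key="state_action_value"):
        super().__init__()
        self.in_keys = [in_key]
        self.out_keys = [out_key]

    def forward(self, td):
        values = td.get(self.in_keys[0])  # assumed shape: [num_nets, ..., num_actions]
        if self.training:
            # training: keep every head so the loss can be computed per network
            td.set(self.out_keys[0], values)
        else:
            # sampling/eval: pessimistic min over the ensemble dimension
            td.set(self.out_keys[0], values.min(dim=0).values)
        return td

If train()/eval() is already used for other purposes, an explicit boolean attribute toggled by the collector would be an alternative switch.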
Additional context
Related to #876