Commit 5daa49a: init

vmoens committed Jan 15, 2024
1 parent 66d5a00 commit 5daa49a
Showing 2 changed files with 86 additions and 38 deletions.
60 changes: 45 additions & 15 deletions torchrl/objectives/a2c.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import warnings
from copy import deepcopy
from dataclasses import dataclass
@@ -46,8 +47,8 @@ class A2CLoss(LossModule):
https://arxiv.org/abs/1602.01783v2
Args:
actor (ProbabilisticTensorDictSequential): policy operator.
critic (ValueOperator): value operator.
actor_network (ProbabilisticTensorDictSequential): policy operator.
critic_network (ValueOperator): value operator.
entropy_bonus (bool): if ``True``, an entropy bonus will be added to the
loss to favour exploratory policies.
samples_mc_entropy (int): if the distribution retrieved from the policy
@@ -221,8 +222,8 @@ class _AcceptedKeys:

def __init__(
self,
actor: ProbabilisticTensorDictSequential,
critic: TensorDictModule,
actor_network: ProbabilisticTensorDictSequential,
critic_network: TensorDictModule,
*,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
@@ -233,23 +234,44 @@ def __init__(
separate_losses: bool = False,
advantage_key: str = None,
value_target_key: str = None,
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
):
if actor is not None:
actor_network = actor
del actor
if critic is not None:
critic_network = critic
del critic

self._out_keys = None
super().__init__()
self._set_deprecated_ctor_keys(
advantage=advantage_key, value_target=value_target_key
)

self.convert_to_functional(
actor, "actor", funs_to_decorate=["forward", "get_dist"]
)
self.functional = functional
if functional:
self.convert_to_functional(
actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"]
)
else:
self.actor_network = actor_network

if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.convert_to_functional(critic, "critic", compare_against=policy_params)
if functional:
self.convert_to_functional(
critic_network, "critic_network", compare_against=policy_params
)
else:
self.critic_network = critic_network

self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus and entropy_coef

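Note: the old `actor`/`critic` keyword-only arguments now exist purely as aliases that are copied into `actor_network`/`critic_network` and then deleted, so the rest of `__init__` only ever sees the new names. A standalone sketch of that pattern follows; the `MyLoss` class and the `DeprecationWarning` are illustrative additions (the diff itself remaps silently), not part of this commit.

```python
import warnings


class MyLoss:
    """Illustrative stand-in, not a torchrl class."""

    def __init__(self, actor_network=None, critic_network=None, *, actor=None, critic=None):
        # Same remapping as in the diff: accept the old keyword, copy it into the
        # new name, then delete the old local so nothing downstream can use it.
        if actor is not None:
            warnings.warn("'actor' will be deprecated; use 'actor_network'", DeprecationWarning)
            actor_network = actor
            del actor
        if critic is not None:
            warnings.warn("'critic' will be deprecated; use 'critic_network'", DeprecationWarning)
            critic_network = critic
            del critic
        self.actor_network = actor_network
        self.critic_network = critic_network


loss = MyLoss(actor="policy", critic="value")  # old spelling still lands in the new attributes
assert loss.actor_network == "policy" and loss.critic_network == "value"
```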
@@ -265,15 +287,19 @@ def __init__(
self.gamma = gamma
self.loss_critic_type = loss_critic_type

@property
def actor(self):
return self.actor_network

@property
def in_keys(self):
keys = [
self.tensor_keys.action,
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
("next", self.tensor_keys.terminated),
*self.actor.in_keys,
*[("next", key) for key in self.actor.in_keys],
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
]
if self.critic_coef:
keys.extend(self.critic.in_keys)
@@ -326,9 +352,11 @@ def _log_probs(
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} require grad."
)
tensordict_clone = tensordict.select(*self.actor.in_keys).clone()
with self.actor_params.to_module(self.actor):
dist = self.actor.get_dist(tensordict_clone)
tensordict_clone = tensordict.select(*self.actor_network.in_keys).clone()
with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict_clone)
log_prob = dist.log_prob(action)
log_prob = log_prob.unsqueeze(-1)
return log_prob, dist
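Note: the `with ... if self.functional else contextlib.nullcontext():` expression (enabled by the new `import contextlib` at the top of the file) selects the context manager at runtime: in functional mode the stored parameters are temporarily swapped into the module, otherwise the block is a no-op and the module uses its own registered parameters. A minimal sketch of the same pattern outside torchrl; `call_module` and its arguments are illustrative.

```python
import contextlib


def call_module(module, params, data, functional: bool = True):
    # `params.to_module(module)` is the TensorDict context manager used in the
    # diff; `nullcontext()` makes the block a no-op when running stateful.
    cm = params.to_module(module) if functional else contextlib.nullcontext()
    with cm:
        return module(data)
```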
@@ -339,7 +367,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
# overhead that we could easily reduce.
target_return = tensordict.get(self.tensor_keys.value_target)
tensordict_select = tensordict.select(*self.critic.in_keys)
with self.critic_params.to_module(self.critic):
with self.critic_params.to_module(
self.critic
) if self.functional else contextlib.nullcontext():
state_value = self.critic(
tensordict_select,
).get(self.tensor_keys.value)
@@ -407,7 +437,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
elif value_type == ValueEstimators.VTrace:
# VTrace currently does not support functional call on the actor
actor_with_params = repopulate_module(
deepcopy(self.actor), self.actor_params
deepcopy(self.actor_network), self.actor_network_params
)
self._value_estimator = VTrace(
value_network=self.critic, actor_network=actor_with_params, **hp
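Note: taken together, the a2c.py changes rename the constructor arguments and add an opt-out of the functional parameter handling. A hedged usage sketch with toy modules follows; the networks below are stand-ins, and only the `A2CLoss` keyword names come from this commit.

```python
from torch import nn
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import A2CLoss

n_obs, n_act = 3, 4
# Toy probabilistic policy: observation -> (loc, scale) -> TanhNormal -> action.
policy = ProbabilisticActor(
    TensorDictModule(
        nn.Sequential(nn.Linear(n_obs, 2 * n_act), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
# Toy value network: observation -> state_value.
value = ValueOperator(nn.Linear(n_obs, 1), in_keys=["observation"])

# New keyword names introduced here; the old `actor=`/`critic=` spellings are
# remapped to these and should keep working for now.
loss_fn = A2CLoss(actor_network=policy, critic_network=value)
# Opt out of the functional parameter swap (behaviour beyond the shown hunks is assumed).
loss_fn_stateful = A2CLoss(actor_network=policy, critic_network=value, functional=False)
```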
64 changes: 41 additions & 23 deletions torchrl/objectives/ppo.py
@@ -49,8 +49,8 @@ class PPOLoss(LossModule):
https://arxiv.org/abs/1707.06347
Args:
actor (ProbabilisticTensorDictSequential): policy operator.
critic (ValueOperator): value operator.
actor_network (ProbabilisticTensorDictSequential): policy operator.
critic_network (ValueOperator): value operator.
Keyword Args:
entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
@@ -259,8 +259,8 @@ class _AcceptedKeys:

def __init__(
self,
actor: ProbabilisticTensorDictSequential,
critic: TensorDictModule,
actor_network: ProbabilisticTensorDictSequential = None,
critic_network: TensorDictModule = None,
*,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
@@ -273,18 +273,30 @@ def __init__(
advantage_key: str = None,
value_target_key: str = None,
value_key: str = None,
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
):
if actor is not None:
actor_network = actor
del actor
if critic is not None:
critic_network = critic
del critic

self._in_keys = None
self._out_keys = None
super().__init__()
self.convert_to_functional(actor, "actor")
self.convert_to_functional(actor_network, "actor_network")
if separate_losses:
# we want to make sure there are no duplicates in the params: the
# params of critic must be refs to actor if they're shared
policy_params = list(actor_network.parameters())
else:
policy_params = None
self.convert_to_functional(critic, "critic", compare_against=policy_params)
self.convert_to_functional(
critic_network, "critic_network", compare_against=policy_params
)
self.samples_mc_entropy = samples_mc_entropy
self.entropy_bonus = entropy_bonus
self.separate_losses = separate_losses
@@ -314,9 +326,9 @@ def _set_in_keys(self):
("next", self.tensor_keys.reward),
("next", self.tensor_keys.done),
("next", self.tensor_keys.terminated),
*self.actor.in_keys,
*[("next", key) for key in self.actor.in_keys],
*self.critic.in_keys,
*self.actor_network.in_keys,
*[("next", key) for key in self.actor_network.in_keys],
*self.critic_network.in_keys,
]
self._in_keys = list(set(keys))

@@ -378,8 +390,8 @@ def _log_weight(
f"tensordict stored {self.tensor_keys.action} requires grad."
)

with self.actor_params.to_module(self.actor):
dist = self.actor.get_dist(tensordict)
with self.actor_network_params.to_module(self.actor_network):
dist = self.actor_network.get_dist(tensordict)
log_prob = dist.log_prob(action)

prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
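Note: `_log_weight` still computes the PPO importance ratio from the current log-probability and the one stored at collection time; only the attribute spellings change here. The quantity itself, on plain tensors (the values below are arbitrary):

```python
import torch

# log w = log pi_new(a|s) - log pi_old(a|s); the PPO objective uses w = exp(log w).
log_prob_new = torch.tensor([-1.2, -0.7])  # recomputed under the current policy
log_prob_old = torch.tensor([-1.0, -0.9])  # sample_log_prob stored in the tensordict
log_weight = log_prob_new - log_prob_old
weight = log_weight.exp()                  # importance ratio pi_new / pi_old
```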
@@ -405,8 +417,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
f"can be used for the value loss."
)

with self.critic_params.to_module(self.critic):
state_value_td = self.critic(tensordict)
with self.critic_network_params.to_module(self.critic_network):
state_value_td = self.critic_network(tensordict)

try:
state_value = state_value_td.get(self.tensor_keys.value)
@@ -426,7 +438,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
@property
@_cache_values
def _cached_critic_params_detached(self):
return self.critic_params.detach()
return self.critic_network_params.detach()

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
@@ -465,20 +477,26 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
hp["gamma"] = self.gamma
hp.update(hyperparams)
if value_type == ValueEstimators.TD1:
self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
self._value_estimator = TD1Estimator(
value_network=self.critic_network, **hp
)
elif value_type == ValueEstimators.TD0:
self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
self._value_estimator = TD0Estimator(
value_network=self.critic_network, **hp
)
elif value_type == ValueEstimators.GAE:
self._value_estimator = GAE(value_network=self.critic, **hp)
self._value_estimator = GAE(value_network=self.critic_network, **hp)
elif value_type == ValueEstimators.TDLambda:
self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
self._value_estimator = TDLambdaEstimator(
value_network=self.critic_network, **hp
)
elif value_type == ValueEstimators.VTrace:
# VTrace currently does not support functional call on the actor
actor_with_params = repopulate_module(
deepcopy(self.actor), self.actor_params
deepcopy(self.actor_network), self.actor_network_params
)
self._value_estimator = VTrace(
value_network=self.critic, actor_network=actor_with_params, **hp
value_network=self.critic_network, actor_network=actor_with_params, **hp
)
else:
raise NotImplementedError(f"Unknown value type {value_type}")
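Note: after the rename, every estimator branch is fed `self.critic_network` instead of `self.critic`; the estimators themselves are unchanged. For reference, a standalone sketch of what the `GAE` branch wires up, with a toy value network (the hyperparameter values are arbitrary, not defaults defined by this commit):

```python
from torch import nn
from torchrl.modules import ValueOperator
from torchrl.objectives.value import GAE

value_net = ValueOperator(nn.Linear(3, 1), in_keys=["observation"])
# Comparable to loss.make_value_estimator(ValueEstimators.GAE, gamma=0.99, lmbda=0.95)
# on a loss module whose critic_network is `value_net`.
advantage = GAE(gamma=0.99, lmbda=0.95, value_network=value_net)
```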
@@ -859,9 +877,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
log_weight, dist = self._log_weight(tensordict)
neg_loss = log_weight.exp() * advantage

previous_dist = self.actor.build_dist_from_params(tensordict)
with self.actor_params.to_module(self.actor):
current_dist = self.actor.get_dist(tensordict)
previous_dist = self.actor_network.build_dist_from_params(tensordict)
with self.actor_network_params.to_module(self.actor_network):
current_dist = self.actor_network.get_dist(tensordict)
try:
kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist)
except NotImplementedError:
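Note: this last hunk touches the KL-penalised variant, where the distribution stored at collection time is compared with the one produced by the current `actor_network`. A sketch of that comparison with plain torch distributions (the concrete distributions and parameters are arbitrary):

```python
import torch
from torch import distributions as d

# Distribution recorded when the batch was collected vs. the current policy's output.
previous_dist = d.Normal(loc=torch.zeros(4), scale=torch.ones(4))
current_dist = d.Normal(loc=0.1 * torch.ones(4), scale=torch.ones(4))

kl = d.kl.kl_divergence(previous_dist, current_dist)  # per-dimension KL, used as a penalty term
```

When `kl_divergence` is not implemented for a given pair of distributions, the code above falls back to the `except NotImplementedError` branch shown in the hunk.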
