
[Deprecations] Deprecate in view of v0.6 release
ghstack-source-id: 0b40197a41f6fc95f15c41970a86a40b911c7d77
Pull Request resolved: #2446
vmoens committed Oct 8, 2024
1 parent b116151 commit 8a99a31
Showing 27 changed files with 140 additions and 253 deletions.
2 changes: 0 additions & 2 deletions docs/source/reference/envs.rst
@@ -996,11 +996,9 @@ Helpers

     RandomPolicy
     check_env_specs
-    exploration_mode #deprecated
     exploration_type
     get_available_libraries
     make_composite_from_td
-    set_exploration_mode #deprecated
     set_exploration_type
     step_mdp
     terminated_or_truncated
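Migration note: the two deprecated helpers map one-to-one onto their typed replacements. A minimal sketch of the change, using only names that appear in this commit (the string "mean" maps to ExplorationType.MODE, as the ray_train.py hunk below shows):

    # before (deprecated):
    #     from torchrl.envs.utils import set_exploration_mode
    #     with set_exploration_mode("random"):
    #         ...

    # after:
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    with set_exploration_type(ExplorationType.RANDOM):
        ...  # exploration wrappers are active inside this block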
6 changes: 3 additions & 3 deletions docs/source/reference/modules.rst
@@ -62,13 +62,13 @@ Exploration wrappers and modules

 To efficiently explore the environment, TorchRL proposes a series of modules
 that will override the action sampled by the policy by a noisier version.
-Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`:
-if the exploration is set to ``"random"``, the exploration is active. In all
+Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`:
+if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all
 other cases, the action written in the tensordict is simply the network output.
 
 .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule`
     uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch.
-    The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on
+    The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on
     this module.
 
 .. currentmodule:: torchrl.modules
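The documented behavior can be checked with a short, self-contained sketch. The toy policy below is illustrative only (its shapes and network are not part of this commit); it assumes a TanhNormal head like the ones used elsewhere in this changeset:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.envs.utils import ExplorationType, set_exploration_type
    from torchrl.modules import NormalParamExtractor, ProbabilisticActor, TanhNormal

    # toy stochastic policy: observation -> (loc, scale) -> TanhNormal action
    net = TensorDictModule(
        torch.nn.Sequential(torch.nn.Linear(4, 8), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )
    policy = ProbabilisticActor(net, in_keys=["loc", "scale"], distribution_class=TanhNormal)

    td = TensorDict({"observation": torch.randn(4)}, [])
    with set_exploration_type(ExplorationType.RANDOM):
        noisy = policy(td.clone())["action"]   # sampled: exploration is active
    with set_exploration_type(ExplorationType.MODE):
        greedy = policy(td.clone())["action"]  # deterministic: the distribution mode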
8 changes: 4 additions & 4 deletions examples/distributed/collectors/multi_nodes/ray_train.py
@@ -26,7 +26,7 @@
     TransformedEnv,
 )
 from torchrl.envs.libs.gym import GymEnv
-from torchrl.envs.utils import check_env_specs, set_exploration_mode
+from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
 from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
@@ -85,8 +85,8 @@
         in_keys=["loc", "scale"],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "min": env.action_spec.space.low,
-            "max": env.action_spec.space.high,
+            "low": env.action_spec.space.low,
+            "high": env.action_spec.space.high,
         },
         return_log_prob=True,
     )
@@ -201,7 +201,7 @@
         stepcount_str = f"step count (max): {logs['step_count'][-1]}"
         logs["lr"].append(optim.param_groups[0]["lr"])
         lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
-        with set_exploration_mode("mean"), torch.no_grad():
+        with set_exploration_type(ExplorationType.MODE), torch.no_grad():
             # execute a rollout with the trained policy
             eval_rollout = env.rollout(1000, policy_module)
             logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
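Migration note for the distribution_kwargs hunk above: the "min"/"max" to "low"/"high" rename also applies when instantiating the distribution directly. A minimal sketch with made-up bounds:

    import torch
    from torchrl.modules import TanhNormal

    # "low"/"high" replace the deprecated "min"/"max" bound arguments
    dist = TanhNormal(
        loc=torch.zeros(3),
        scale=torch.ones(3),
        low=-torch.ones(3),
        high=torch.ones(3),
    )
    action = dist.rsample()  # samples fall inside (low, high)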
8 changes: 3 additions & 5 deletions sota-implementations/decision_transformer/utils.py
@@ -38,7 +38,7 @@
 )
 from torchrl.envs.libs.dm_control import DMControlEnv
 from torchrl.envs.libs.gym import set_gym_backend
-from torchrl.envs.utils import set_exploration_mode
+from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import (
     DTActor,
     OnlineDTActor,
@@ -374,13 +374,12 @@ def make_odt_model(cfg):
         module=actor_module,
         distribution_class=dist_class,
         distribution_kwargs=dist_kwargs,
-        default_interaction_mode="random",
         cache_dist=False,
         return_log_prob=False,
     )
 
     # init the lazy layers
-    with torch.no_grad(), set_exploration_mode("random"):
+    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
         td = proof_environment.rollout(max_steps=100)
         td["action"] = td["next", "action"]
         actor(td)
@@ -428,13 +427,12 @@ def make_dt_model(cfg):
         module=actor_module,
         distribution_class=dist_class,
         distribution_kwargs=dist_kwargs,
-        default_interaction_mode="random",
         cache_dist=False,
         return_log_prob=False,
     )
 
     # init the lazy layers
-    with torch.no_grad(), set_exploration_mode("random"):
+    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
         td = proof_environment.rollout(max_steps=100)
         td["action"] = td["next", "action"]
         actor(td)
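Migration note: with default_interaction_mode removed from the actor constructor, the sampling behavior is chosen at the call site instead of being baked into the module. Both factory functions above now follow this pattern (actor and proof_environment are built as in the surrounding file):

    import torch
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    # the context manager replaces default_interaction_mode="random"
    with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM):
        td = proof_environment.rollout(max_steps=100)
        td["action"] = td["next", "action"]
        actor(td)  # triggers lazy-layer initialization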
1 change: 0 additions & 1 deletion sota-implementations/redq/config.yaml
@@ -36,7 +36,6 @@ collector:
   multi_step: 1
   n_steps_return: 3
   max_frames_per_traj: -1
-  exploration_mode: random
 
 logger:
   backend: wandb
2 changes: 1 addition & 1 deletion sota-implementations/redq/utils.py
@@ -1021,7 +1021,7 @@ def make_collector_offpolicy(
"init_random_frames": cfg.collector.init_random_frames,
"split_trajs": True,
# trajectories must be separated if multi-step is used
"exploration_type": ExplorationType.from_str(cfg.collector.exploration_mode),
"exploration_type": cfg.collector.exploration_type,
}

collector = collector_helper(**collector_helper_kwargs)
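Migration note: the helper no longer parses a string via ExplorationType.from_str; the config is now expected to carry the exploration type directly. A hedged sketch with a hypothetical stand-in for the hydra config object:

    from torchrl.envs.utils import ExplorationType

    class _CollectorCfg:  # hypothetical stand-in for cfg.collector
        exploration_type = ExplorationType.RANDOM

    collector_helper_kwargs = {
        # old: ExplorationType.from_str(cfg.collector.exploration_mode)
        "exploration_type": _CollectorCfg.exploration_type,
    }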
8 changes: 4 additions & 4 deletions test/test_actors.py
@@ -54,8 +54,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
         out_keys=[("data", "action")],
         distribution_class=TanhDelta,
         distribution_kwargs={
-            "min": action_spec.space.low,
-            "max": action_spec.space.high,
+            "low": action_spec.space.low,
+            "high": action_spec.space.high,
         },
         log_prob_key=log_prob_key,
         return_log_prob=True,
@@ -77,8 +77,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions=
         out_keys=[("data", "action")],
         distribution_class=TanhDelta,
         distribution_kwargs={
-            "min": action_spec.space.low,
-            "max": action_spec.space.high,
+            "low": action_spec.space.low,
+            "high": action_spec.space.high,
         },
         log_prob_key=log_prob_key,
         return_log_prob=True,
10 changes: 5 additions & 5 deletions test/test_distributions.py
@@ -190,8 +190,8 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device):
         d = TruncatedNormal(
             *vecs,
             upscale=upscale,
-            min=min,
-            max=max,
+            low=min,
+            high=max,
         )
         assert d.device == device
         for _ in range(100):
@@ -218,7 +218,7 @@ def test_truncnormal_against_scipy(self):
         high = 2
         low = -1
         log_pi_x = TruncatedNormal(
-            mu, sigma, min=low, max=high, tanh_loc=False
+            mu, sigma, low=low, high=high, tanh_loc=False
         ).log_prob(x)
         pi_x = torch.exp(log_pi_x)
         log_pi_x.backward(torch.ones_like(log_pi_x))
@@ -264,8 +264,8 @@ def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device):
         d = TruncatedNormal(
             *vecs,
             upscale=upscale,
-            min=min,
-            max=max,
+            low=min,
+            high=max,
         )
         assert d.mode is not None
         assert d.entropy() is not None
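The rename is identical for TruncatedNormal; a quick sketch mirroring the tests above (values are illustrative):

    import torch
    from torchrl.modules import TruncatedNormal

    mu = torch.zeros(5)
    sigma = torch.ones(5)
    # low/high replace min/max, exactly as in the tests above
    dist = TruncatedNormal(mu, sigma, low=-1.0, high=2.0, tanh_loc=False)
    log_p = dist.log_prob(dist.sample())  # finite: samples lie within the bounds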
2 changes: 1 addition & 1 deletion test/test_libs.py
@@ -3065,7 +3065,7 @@ def test_atari_preproc(self, dataset_id, tmpdir):

         t = Compose(
             UnsqueezeTransform(
-                unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")]
+                dim=-3, in_keys=["observation", ("next", "observation")]
             ),
             Resize(32, in_keys=["observation", ("next", "observation")]),
             RenameTransform(in_keys=["action"], out_keys=["other_action"]),
6 changes: 2 additions & 4 deletions test/test_rb.py
@@ -1776,10 +1776,8 @@ def test_insert_transform(self):
                 not _has_tv, reason="needs torchvision dependency"
             ),
         ),
-        pytest.param(
-            partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform"
-        ),
-        pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"),
+        pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"),
+        pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"),
         GrayScale,
         pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"),
         pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"),
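Migration note: unsqueeze_dim and squeeze_dim collapse into a single dim argument on both transforms. A minimal sketch with an illustrative key:

    from torchrl.envs import SqueezeTransform, UnsqueezeTransform

    # both transforms now take dim, matching the test parametrization above
    unsqueeze = UnsqueezeTransform(dim=-1, in_keys=["observation"])
    squeeze = SqueezeTransform(dim=-1, in_keys=["observation"])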