From ce0a74cb45e960c7d23abf38e56b3a066f159e51 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 8 Oct 2024 12:54:34 +0100 Subject: [PATCH] [Deprecations] Deprecate in view of v0.6 release ghstack-source-id: 105e1e2215d14d774086e5da31ad366457a7e84c Pull Request resolved: https://github.com/pytorch/rl/pull/2446 --- docs/source/reference/envs.rst | 2 - docs/source/reference/modules.rst | 6 +- .../collectors/multi_nodes/ray_train.py | 8 +- .../decision_transformer/utils.py | 8 +- sota-implementations/redq/config.yaml | 1 - sota-implementations/redq/utils.py | 2 +- test/test_actors.py | 8 +- test/test_distributions.py | 10 +- test/test_libs.py | 2 +- test/test_rb.py | 6 +- test/test_transforms.py | 128 ++++++++---------- torchrl/collectors/collectors.py | 30 +--- torchrl/collectors/distributed/generic.py | 5 - torchrl/collectors/distributed/rpc.py | 5 - torchrl/collectors/distributed/sync.py | 5 - torchrl/envs/__init__.py | 2 - torchrl/envs/transforms/r3m.py | 2 +- torchrl/envs/transforms/transforms.py | 60 +++++--- torchrl/envs/transforms/vip.py | 2 +- torchrl/envs/utils.py | 13 -- torchrl/modules/distributions/continuous.py | 55 +------- torchrl/modules/models/exploration.py | 2 +- .../tensordict_module/probabilistic.py | 2 - torchrl/objectives/common.py | 4 +- torchrl/objectives/value/advantages.py | 13 +- torchrl/trainers/helpers/collectors.py | 7 +- tutorials/sphinx-tutorials/pendulum.py | 2 +- 27 files changed, 139 insertions(+), 251 deletions(-) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index afef09aa312..3578cbfd79f 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -996,11 +996,9 @@ Helpers RandomPolicy check_env_specs - exploration_mode #deprecated exploration_type get_available_libraries make_composite_from_td - set_exploration_mode #deprecated set_exploration_type step_mdp terminated_or_truncated diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst index 2d6a6344970..e1642868228 100644 --- a/docs/source/reference/modules.rst +++ b/docs/source/reference/modules.rst @@ -62,13 +62,13 @@ Exploration wrappers and modules To efficiently explore the environment, TorchRL proposes a series of modules that will override the action sampled by the policy by a noisier version. -Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_mode`: -if the exploration is set to ``"random"``, the exploration is active. In all +Their behavior is controlled by :func:`~torchrl.envs.utils.exploration_type`: +if the exploration is set to ``ExplorationType.RANDOM``, the exploration is active. In all other cases, the action written in the tensordict is simply the network output. .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. - The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on this module. .. currentmodule:: torchrl.modules diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index b05e92619fa..5697d88dc61 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -26,7 +26,7 @@ TransformedEnv, ) from torchrl.envs.libs.gym import GymEnv -from torchrl.envs.utils import check_env_specs, set_exploration_mode +from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator from torchrl.objectives import ClipPPOLoss from torchrl.objectives.value import GAE @@ -85,8 +85,8 @@ in_keys=["loc", "scale"], distribution_class=TanhNormal, distribution_kwargs={ - "min": env.action_spec.space.low, - "max": env.action_spec.space.high, + "low": env.action_spec.space.low, + "high": env.action_spec.space.high, }, return_log_prob=True, ) @@ -201,7 +201,7 @@ stepcount_str = f"step count (max): {logs['step_count'][-1]}" logs["lr"].append(optim.param_groups[0]["lr"]) lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}" - with set_exploration_mode("mean"), torch.no_grad(): + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): # execute a rollout with the trained policy eval_rollout = env.rollout(1000, policy_module) logs["eval reward"].append(eval_rollout["next", "reward"].mean().item()) diff --git a/sota-implementations/decision_transformer/utils.py b/sota-implementations/decision_transformer/utils.py index 409833c75fa..ee2cc6e424c 100644 --- a/sota-implementations/decision_transformer/utils.py +++ b/sota-implementations/decision_transformer/utils.py @@ -38,7 +38,7 @@ ) from torchrl.envs.libs.dm_control import DMControlEnv from torchrl.envs.libs.gym import set_gym_backend -from torchrl.envs.utils import set_exploration_mode +from torchrl.envs.utils import ExplorationType, set_exploration_type from torchrl.modules import ( DTActor, OnlineDTActor, @@ -374,13 +374,12 @@ def make_odt_model(cfg): module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, - default_interaction_mode="random", cache_dist=False, return_log_prob=False, ) # init the lazy layers - with torch.no_grad(), set_exploration_mode("random"): + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] actor(td) @@ -428,13 +427,12 @@ def make_dt_model(cfg): module=actor_module, distribution_class=dist_class, distribution_kwargs=dist_kwargs, - default_interaction_mode="random", cache_dist=False, return_log_prob=False, ) # init the lazy layers - with torch.no_grad(), set_exploration_mode("random"): + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): td = proof_environment.rollout(max_steps=100) td["action"] = td["next", "action"] actor(td) diff --git a/sota-implementations/redq/config.yaml b/sota-implementations/redq/config.yaml index e60191c0f93..818f3386fda 100644 --- a/sota-implementations/redq/config.yaml +++ b/sota-implementations/redq/config.yaml @@ -36,7 +36,6 @@ collector: multi_step: 1 n_steps_return: 3 max_frames_per_traj: -1 - exploration_mode: random logger: backend: wandb diff --git a/sota-implementations/redq/utils.py b/sota-implementations/redq/utils.py index dd922372cbb..8a3c9ae3f79 100644 --- a/sota-implementations/redq/utils.py +++ b/sota-implementations/redq/utils.py @@ -1021,7 +1021,7 @@ def make_collector_offpolicy( "init_random_frames": cfg.collector.init_random_frames, "split_trajs": True, # trajectories must be separated if multi-step is used - "exploration_type": ExplorationType.from_str(cfg.collector.exploration_mode), + "exploration_type": cfg.collector.exploration_type, } collector = collector_helper(**collector_helper_kwargs) diff --git a/test/test_actors.py b/test/test_actors.py index 439094e922a..b81f322b708 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -54,8 +54,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.low, - "max": action_spec.space.high, + "low": action_spec.space.low, + "high": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, @@ -77,8 +77,8 @@ def test_probabilistic_actor_nested_delta(log_prob_key, nested_dim=5, n_actions= out_keys=[("data", "action")], distribution_class=TanhDelta, distribution_kwargs={ - "min": action_spec.space.low, - "max": action_spec.space.high, + "low": action_spec.space.low, + "high": action_spec.space.high, }, log_prob_key=log_prob_key, return_log_prob=True, diff --git a/test/test_distributions.py b/test/test_distributions.py index 53bfda343a2..8a5b651531e 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -190,8 +190,8 @@ def test_truncnormal(self, min, max, vecs, upscale, shape, device): d = TruncatedNormal( *vecs, upscale=upscale, - min=min, - max=max, + low=min, + high=max, ) assert d.device == device for _ in range(100): @@ -218,7 +218,7 @@ def test_truncnormal_against_scipy(self): high = 2 low = -1 log_pi_x = TruncatedNormal( - mu, sigma, min=low, max=high, tanh_loc=False + mu, sigma, low=low, high=high, tanh_loc=False ).log_prob(x) pi_x = torch.exp(log_pi_x) log_pi_x.backward(torch.ones_like(log_pi_x)) @@ -264,8 +264,8 @@ def test_truncnormal_mode(self, min, max, vecs, upscale, shape, device): d = TruncatedNormal( *vecs, upscale=upscale, - min=min, - max=max, + low=min, + high=max, ) assert d.mode is not None assert d.entropy() is not None diff --git a/test/test_libs.py b/test/test_libs.py index 87c69bf000c..6fc2979607d 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3065,7 +3065,7 @@ def test_atari_preproc(self, dataset_id, tmpdir): t = Compose( UnsqueezeTransform( - unsqueeze_dim=-3, in_keys=["observation", ("next", "observation")] + dim=-3, in_keys=["observation", ("next", "observation")] ), Resize(32, in_keys=["observation", ("next", "observation")]), RenameTransform(in_keys=["action"], out_keys=["other_action"]), diff --git a/test/test_rb.py b/test/test_rb.py index 34b34b5b486..24b33f89795 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1776,10 +1776,8 @@ def test_insert_transform(self): not _has_tv, reason="needs torchvision dependency" ), ), - pytest.param( - partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" - ), - pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), + pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), GrayScale, pytest.param(partial(ObservationNorm, loc=1, scale=2), id="ObservationNorm"), pytest.param(partial(CatFrames, dim=-3, N=4), id="CatFrames"), diff --git a/test/test_transforms.py b/test/test_transforms.py index 589c32809cc..55b9a73e054 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -5627,7 +5627,7 @@ def test_transform_model(self): class TestUnsqueezeTransform(TransformBase): - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5635,14 +5635,10 @@ class TestUnsqueezeTransform(TransformBase): "keys", [["observation", ("some_other", "nested_key")], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_default_devices()) - def test_transform_no_env( - self, keys, size, nchannels, batch, device, unsqueeze_dim - ): + def test_transform_no_env(self, keys, size, nchannels, batch, device, dim): torch.manual_seed(0) dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device) - unsqueeze = UnsqueezeTransform( - unsqueeze_dim, in_keys=keys, allow_positive_dim=True - ) + unsqueeze = UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True) td = TensorDict( { key: torch.randn(*batch, *size, nchannels, 16, 16, device=device) @@ -5652,16 +5648,16 @@ def test_transform_no_env( device=device, ) td.set("dont touch", dont_touch.clone()) - if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch): + if dim >= 0 and dim < len(batch): with pytest.raises(RuntimeError, match="batch dimension mismatch"): unsqueeze(td) return unsqueeze(td) expected_size = [*batch, *size, nchannels, 16, 16] - if unsqueeze_dim < 0: - expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(unsqueeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys: @@ -5669,7 +5665,7 @@ def test_transform_no_env( batch, size, nchannels, - unsqueeze_dim, + dim, ) assert (td.get("dont touch") == dont_touch).all() @@ -5688,7 +5684,7 @@ def test_transform_no_env( for key in keys: assert observation_spec[key].shape == expected_size - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5704,13 +5700,11 @@ def test_transform_no_env( [("next", "observation_pixels")], ], ) - def test_unsqueeze_inv( - self, keys, keys_inv, size, nchannels, batch, device, unsqueeze_dim - ): + def test_unsqueeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim): torch.manual_seed(0) keys_total = set(keys + keys_inv) unsqueeze = UnsqueezeTransform( - unsqueeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -5726,8 +5720,8 @@ def test_unsqueeze_inv( for key in keys_total.difference(keys_inv): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[unsqueeze_dim] == 1: - del expected_size[unsqueeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys_inv: assert td_modif.get(key).shape == torch.Size(expected_size) # for key in keys_inv: @@ -5787,7 +5781,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): except RuntimeError: pass - @pytest.mark.parametrize("unsqueeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5795,13 +5789,11 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): "keys", [["observation", "some_other_key"], ["observation_pixels"]] ) @pytest.mark.parametrize("device", get_default_devices()) - def test_transform_compose( - self, keys, size, nchannels, batch, device, unsqueeze_dim - ): + def test_transform_compose(self, keys, size, nchannels, batch, device, dim): torch.manual_seed(0) dont_touch = torch.randn(*batch, *size, nchannels, 16, 16, device=device) unsqueeze = Compose( - UnsqueezeTransform(unsqueeze_dim, in_keys=keys, allow_positive_dim=True) + UnsqueezeTransform(dim, in_keys=keys, allow_positive_dim=True) ) td = TensorDict( { @@ -5812,16 +5804,16 @@ def test_transform_compose( device=device, ) td.set("dont touch", dont_touch.clone()) - if unsqueeze_dim >= 0 and unsqueeze_dim < len(batch): + if dim >= 0 and dim < len(batch): with pytest.raises(RuntimeError, match="batch dimension mismatch"): unsqueeze(td) return unsqueeze(td) expected_size = [*batch, *size, nchannels, 16, 16] - if unsqueeze_dim < 0: - expected_size.insert(len(expected_size) + unsqueeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(unsqueeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys: @@ -5829,7 +5821,7 @@ def test_transform_compose( batch, size, nchannels, - unsqueeze_dim, + dim, ) assert (td.get("dont touch") == dont_touch).all() @@ -5865,10 +5857,10 @@ def test_transform_env(self, out_keys): check_env_specs(env) @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) - @pytest.mark.parametrize("unsqueeze_dim", [-1, 1]) - def test_transform_model(self, out_keys, unsqueeze_dim): + @pytest.mark.parametrize("dim", [-1, 1]) + def test_transform_model(self, out_keys, dim): t = UnsqueezeTransform( - unsqueeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -5878,21 +5870,21 @@ def test_transform_model(self, out_keys, unsqueeze_dim): ) t(td) expected_shape = [3, 4] - if unsqueeze_dim >= 0: - expected_shape.insert(unsqueeze_dim, 1) + if dim >= 0: + expected_shape.insert(dim, 1) else: - expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1) + expected_shape.insert(len(expected_shape) + dim + 1, 1) if out_keys is None: assert td["observation"].shape == torch.Size(expected_shape) else: assert td[out_keys[0]].shape == torch.Size(expected_shape) @pytest.mark.parametrize("out_keys", [None, ["stuff"]]) - @pytest.mark.parametrize("unsqueeze_dim", [-1, 1]) + @pytest.mark.parametrize("dim", [-1, 1]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) - def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim): + def test_transform_rb(self, rbclass, out_keys, dim): t = UnsqueezeTransform( - unsqueeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -5905,10 +5897,10 @@ def test_transform_rb(self, rbclass, out_keys, unsqueeze_dim): rb.extend(td) td = rb.sample(2) expected_shape = [2, 3, 4] - if unsqueeze_dim >= 0: - expected_shape.insert(unsqueeze_dim, 1) + if dim >= 0: + expected_shape.insert(dim, 1) else: - expected_shape.insert(len(expected_shape) + unsqueeze_dim + 1, 1) + expected_shape.insert(len(expected_shape) + dim + 1, 1) if out_keys is None: assert td["observation"].shape == torch.Size(expected_shape) else: @@ -5932,7 +5924,7 @@ def test_transform_inverse(self): class TestSqueezeTransform(TransformBase): - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5953,12 +5945,12 @@ class TestSqueezeTransform(TransformBase): ], ) def test_transform_no_env( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim + self, keys, keys_inv, size, nchannels, batch, device, dim ): torch.manual_seed(0) keys_total = set(keys + keys_inv) squeeze = SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -5973,12 +5965,12 @@ def test_transform_no_env( for key in keys_total.difference(keys): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[squeeze_dim] == 1: - del expected_size[squeeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys: assert td.get(key).shape == torch.Size(expected_size) - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -5998,15 +5990,13 @@ def test_transform_no_env( [("next", "observation_pixels")], ], ) - def test_squeeze_inv( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim - ): + def test_squeeze_inv(self, keys, keys_inv, size, nchannels, batch, device, dim): torch.manual_seed(0) - if squeeze_dim >= 0: - squeeze_dim = squeeze_dim + len(batch) + if dim >= 0: + dim = dim + len(batch) keys_total = set(keys + keys_inv) squeeze = SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) td = TensorDict( { @@ -6021,14 +6011,14 @@ def test_squeeze_inv( for key in keys_total.difference(keys_inv): assert td.get(key).shape == torch.Size(expected_size) - if squeeze_dim < 0: - expected_size.insert(len(expected_size) + squeeze_dim + 1, 1) + if dim < 0: + expected_size.insert(len(expected_size) + dim + 1, 1) else: - expected_size.insert(squeeze_dim, 1) + expected_size.insert(dim, 1) expected_size = torch.Size(expected_size) for key in keys_inv: - assert td.get(key).shape == torch.Size(expected_size), squeeze_dim + assert td.get(key).shape == torch.Size(expected_size), dim @property def _circular_transform(self): @@ -6101,7 +6091,7 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): except RuntimeError: pass - @pytest.mark.parametrize("squeeze_dim", [1, -2]) + @pytest.mark.parametrize("dim", [1, -2]) @pytest.mark.parametrize("nchannels", [1, 3]) @pytest.mark.parametrize("batch", [[], [2], [2, 4]]) @pytest.mark.parametrize("size", [[], [4]]) @@ -6114,13 +6104,13 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv): "keys_inv", [[], ["action", "some_other_key"], [("next", "observation_pixels")]] ) def test_transform_compose( - self, keys, keys_inv, size, nchannels, batch, device, squeeze_dim + self, keys, keys_inv, size, nchannels, batch, device, dim ): torch.manual_seed(0) keys_total = set(keys + keys_inv) squeeze = Compose( SqueezeTransform( - squeeze_dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True + dim, in_keys=keys, in_keys_inv=keys_inv, allow_positive_dim=True ) ) td = TensorDict( @@ -6136,8 +6126,8 @@ def test_transform_compose( for key in keys_total.difference(keys): assert td.get(key).shape == torch.Size(expected_size) - if expected_size[squeeze_dim] == 1: - del expected_size[squeeze_dim] + if expected_size[dim] == 1: + del expected_size[dim] for key in keys: assert td.get(key).shape == torch.Size(expected_size) @@ -6154,9 +6144,9 @@ def test_transform_env(self, keys_inv): @pytest.mark.parametrize("out_keys", [None, ["obs_sq"]]) def test_transform_model(self, out_keys): - squeeze_dim = 1 + dim = 1 t = SqueezeTransform( - squeeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -6175,9 +6165,9 @@ def test_transform_model(self, out_keys): @pytest.mark.parametrize("out_keys", [None, ["obs_sq"]]) @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer]) def test_transform_rb(self, out_keys, rbclass): - squeeze_dim = -2 + dim = -2 t = SqueezeTransform( - squeeze_dim, + dim, in_keys=["observation"], out_keys=out_keys, allow_positive_dim=True, @@ -8925,10 +8915,8 @@ def test_batch_unlocked_with_batch_size_transformed(device): pytest.param( partial(FlattenObservation, first_dim=-3, last_dim=-3), id="FlattenObservation" ), - pytest.param( - partial(UnsqueezeTransform, unsqueeze_dim=-1), id="UnsqueezeTransform" - ), - pytest.param(partial(SqueezeTransform, squeeze_dim=-1), id="SqueezeTransform"), + pytest.param(partial(UnsqueezeTransform, dim=-1), id="UnsqueezeTransform"), + pytest.param(partial(SqueezeTransform, dim=-1), id="SqueezeTransform"), GrayScale, pytest.param( partial(ObservationNorm, in_keys=["observation"]), id="ObservationNorm" diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 9ccd2e2aa80..3acc4bd8300 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -58,7 +58,6 @@ from torchrl.envs.transforms import StepCounter, TransformedEnv from torchrl.envs.utils import ( _aggregate_end_of_traj, - _convert_exploration_type, _make_compatible_policy, ExplorationType, RandomPolicy, @@ -489,7 +488,6 @@ def __init__( postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, split_trajs: bool | None = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode: str | None = None, return_same_td: bool = False, reset_when_done: bool = True, interruptor=None, @@ -502,9 +500,6 @@ def __init__( from torchrl.envs.batched_envs import BatchedEnvBase self.closed = True - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if create_env_kwargs is None: create_env_kwargs = {} if not isinstance(create_env_fn, EnvBase): @@ -1472,7 +1467,7 @@ class _MultiDataCollector(DataCollectorBase): A ``cat_results`` value of ``-1`` will always concatenate results along the time dimension. This should be preferred over the default. Intermediate values are also accepted. - Defaults to ``0``. + Defaults to ``"stack"``. .. note:: From v0.5, this argument will default to ``"stack"`` for a better interoperability with the rest of the library. @@ -1516,7 +1511,6 @@ def __init__( postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, reset_when_done: bool = True, update_at_each_batch: bool = False, preemptive_threshold: float = None, @@ -1529,9 +1523,6 @@ def __init__( replay_buffer_chunk: bool = True, trust_policy: bool = None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) self.closed = True self.num_workers = len(create_env_fn) @@ -1675,10 +1666,12 @@ def __init__( self.cat_results = cat_results def _check_replay_buffer_init(self): + if self.replay_buffer is None: + return is_init = getattr(self.replay_buffer._storage, "initialized", True) if not is_init: if isinstance(self.create_env_fn[0], EnvCreator): - fake_td = self.create_env_fn[0].tensordict + fake_td = self.create_env_fn[0].meta_data.tensordict elif isinstance(self.create_env_fn[0], EnvBase): fake_td = self.create_env_fn[0].fake_tensordict() else: @@ -2173,19 +2166,6 @@ def iterator(self) -> Iterator[TensorDictBase]: cat_results = self.cat_results if cat_results is None: cat_results = "stack" - warnings.warn( - f"`cat_results` was not specified in the constructor of {type(self).__name__}. " - f"For MultiSyncDataCollector, `cat_results` indicates how the data should " - f"be packed: the preferred option and current default is `cat_results='stack'` " - f"which provides the best interoperability across torchrl components. " - f"Other accepted values are `cat_results=0` (previous behavior) and " - f"`cat_results=-1` (cat along time dimension). Among these two, the latter " - f"should be preferred for consistency across environment configurations. " - f"Currently, the default value is `'stack'`." - f"From v0.6 onward, this warning will be removed. " - f"To suppress this warning, set `cat_results` to the desired value.", - category=DeprecationWarning, - ) self.buffers = {} dones = [False for _ in range(self.num_workers)] @@ -2770,7 +2750,6 @@ def __init__( postproc: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, split_trajs: Optional[bool] = None, exploration_type: ExplorationType = DEFAULT_EXPLORATION_TYPE, - exploration_mode=None, reset_when_done: bool = True, update_at_each_batch: bool = False, preemptive_threshold: float = None, @@ -2795,7 +2774,6 @@ def __init__( env_device=env_device, storing_device=storing_device, exploration_type=exploration_type, - exploration_mode=exploration_mode, reset_when_done=reset_when_done, update_at_each_batch=update_at_each_batch, preemptive_threshold=preemptive_threshold, diff --git a/torchrl/collectors/distributed/generic.py b/torchrl/collectors/distributed/generic.py index 65e6987b4aa..729b8a48171 100644 --- a/torchrl/collectors/distributed/generic.py +++ b/torchrl/collectors/distributed/generic.py @@ -34,7 +34,6 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -419,7 +418,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class: Type = SyncDataCollector, collector_kwargs: dict = None, num_workers_per_collector: int = 1, @@ -431,9 +429,6 @@ def __init__( launcher: str = "submitit", tcp_port: int = None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector diff --git a/torchrl/collectors/distributed/rpc.py b/torchrl/collectors/distributed/rpc.py index 816364cf84a..73247df4b0c 100644 --- a/torchrl/collectors/distributed/rpc.py +++ b/torchrl/collectors/distributed/rpc.py @@ -24,7 +24,6 @@ ) from torchrl.collectors.utils import _NON_NN_POLICY_WEIGHTS, split_trajectories from torchrl.data.utils import CloudpickleWrapper -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -275,7 +274,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -288,9 +286,6 @@ def __init__( visible_devices=None, tensorpipe_options=None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector elif collector_class == "sync": diff --git a/torchrl/collectors/distributed/sync.py b/torchrl/collectors/distributed/sync.py index 744bce1446f..481fb70cc31 100644 --- a/torchrl/collectors/distributed/sync.py +++ b/torchrl/collectors/distributed/sync.py @@ -34,7 +34,6 @@ from torchrl.data.utils import CloudpickleWrapper from torchrl.envs.common import EnvBase from torchrl.envs.env_creator import EnvCreator -from torchrl.envs.utils import _convert_exploration_type SUBMITIT_ERR = None try: @@ -285,7 +284,6 @@ def __init__( postproc: Callable | None = None, split_trajs: bool = False, exploration_type: "ExporationType" = DEFAULT_EXPLORATION_TYPE, # noqa - exploration_mode: str = None, collector_class=SyncDataCollector, collector_kwargs=None, num_workers_per_collector=1, @@ -296,9 +294,6 @@ def __init__( launcher="submitit", tcp_port=None, ): - exploration_type = _convert_exploration_type( - exploration_mode=exploration_mode, exploration_type=exploration_type - ) if collector_class == "async": collector_class = MultiaSyncDataCollector diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index c8b7fd4aafb..d0d92251b69 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -102,12 +102,10 @@ from .utils import ( check_env_specs, check_marl_grouping, - exploration_mode, exploration_type, ExplorationType, make_composite_from_td, MarlGroupMapType, - set_exploration_mode, set_exploration_type, step_mdp, ) diff --git a/torchrl/envs/transforms/r3m.py b/torchrl/envs/transforms/r3m.py index d4505a4d240..bdc8af1eefa 100644 --- a/torchrl/envs/transforms/r3m.py +++ b/torchrl/envs/transforms/r3m.py @@ -315,7 +315,7 @@ def _init(self): unsqueeze = UnsqueezeTransform( in_keys=in_keys, out_keys=in_keys, - unsqueeze_dim=-4, + dim=-4, ) transforms.append(unsqueeze) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a95a14d42ad..8655002d971 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -1347,11 +1347,11 @@ def _apply_transform(self, observation: torch.FloatTensor) -> torch.Tensor: @_apply_to_composite def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: observation_spec = self._pixel_observation(observation_spec) - unsqueeze_dim = [1] if self._should_unsqueeze(observation_spec) else [] + dim = [1] if self._should_unsqueeze(observation_spec) else [] if not self.shape_tolerant or observation_spec.shape[-1] == 3: observation_spec.shape = torch.Size( [ - *unsqueeze_dim, + *dim, *observation_spec.shape[:-3], observation_spec.shape[-1], observation_spec.shape[-3], @@ -2136,41 +2136,42 @@ class UnsqueezeTransform(Transform): """Inserts a dimension of size one at the specified position. Args: - unsqueeze_dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim + dim (int): dimension to unsqueeze. Must be negative (or allow_positive_dim must be turned on). + + Keyword Args: allow_positive_dim (bool, optional): if ``True``, positive dimensions are accepted. - :obj:`UnsqueezeTransform` will map these to the n^th feature dimension + `UnsqueezeTransform`` will map these to the n^th feature dimension (ie n^th dimension after batch size of parent env) of the input tensor, - independently from the tensordict batch size (ie positive dims may be + independently of the tensordict batch size (ie positive dims may be dangerous in contexts where tensordict of different batch dimension are passed). Defaults to False, ie. non-negative dimensions are not permitted. + in_keys (list of NestedKeys): input entries (read). + out_keys (list of NestedKeys): input entries (write). Defaults to ``in_keys`` if + not provided. + in_keys_inv (list of NestedKeys): input entries (read) during :meth:`~.inv` calls. + out_keys_inv (list of NestedKeys): input entries (write) during :meth:`~.inv` calls. + Defaults to ``in_keys_in`` if not provided. """ invertible = True @classmethod def __new__(cls, *args, **kwargs): - cls._unsqueeze_dim = None + cls._dim = None return super().__new__(cls) def __init__( self, dim: int = None, + *, allow_positive_dim: bool = False, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, - **kwargs, ): - if "unsqueeze_dim" in kwargs: - warnings.warn( - "The `unsqueeze_dim` kwarg will be removed in v0.6. Please use `dim` instead." - ) - dim = kwargs["unsqueeze_dim"] - elif dim is None: - raise TypeError("dim must be provided.") if in_keys is None: in_keys = [] # default if out_keys is None: @@ -2190,22 +2191,26 @@ def __init__( raise RuntimeError( "dim should be smaller than 0 to accommodate for " "envs of different batch_sizes. Turn allow_positive_dim to accommodate " - "for positive unsqueeze_dim." + "for positive dim." ) self._dim = dim @property def unsqueeze_dim(self): + return self.dim + + @property + def dim(self): if self._dim >= 0 and self.parent is not None: return len(self.parent.batch_size) + self._dim return self._dim def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor: - observation = observation.unsqueeze(self.unsqueeze_dim) + observation = observation.unsqueeze(self.dim) return observation def _inv_apply_transform(self, observation: torch.Tensor) -> torch.Tensor: - observation = observation.squeeze(self.unsqueeze_dim) + observation = observation.squeeze(self.dim) return observation def _transform_spec(self, spec: TensorSpec): @@ -2252,7 +2257,7 @@ def _reset( def __repr__(self) -> str: s = ( - f"{self.__class__.__name__}(unsqueeze_dim={self.unsqueeze_dim}, in_keys={self.in_keys}, out_keys={self.out_keys}," + f"{self.__class__.__name__}(dim={self.dim}, in_keys={self.in_keys}, out_keys={self.out_keys}," f" in_keys_inv={self.in_keys_inv}, out_keys_inv={self.out_keys_inv})" ) return s @@ -2262,14 +2267,14 @@ class SqueezeTransform(UnsqueezeTransform): """Removes a dimension of size one at the specified position. Args: - squeeze_dim (int): dimension to squeeze. + dim (int): dimension to squeeze. """ invertible = True def __init__( self, - squeeze_dim: int, + dim: int | None = None, *args, in_keys: Optional[Sequence[str]] = None, out_keys: Optional[Sequence[str]] = None, @@ -2277,8 +2282,19 @@ def __init__( out_keys_inv: Optional[Sequence[str]] = None, **kwargs, ): + if dim is None: + if "squeeze_dim" in kwargs: + warnings.warn( + f"squeeze_dim will be deprecated in favor of dim arg in {type(self).__name__}." + ) + dim = kwargs.pop("squeeze_dim") + else: + raise TypeError( + f"dim must be passed to {type(self).__name__} constructor." + ) + super().__init__( - squeeze_dim, + dim, *args, in_keys=in_keys, out_keys=out_keys, @@ -2289,7 +2305,7 @@ def __init__( @property def squeeze_dim(self): - return super().unsqueeze_dim + return super().dim _apply_transform = UnsqueezeTransform._inv_apply_transform _inv_apply_transform = UnsqueezeTransform._apply_transform diff --git a/torchrl/envs/transforms/vip.py b/torchrl/envs/transforms/vip.py index 556eacf579c..a28e490c4f1 100644 --- a/torchrl/envs/transforms/vip.py +++ b/torchrl/envs/transforms/vip.py @@ -285,7 +285,7 @@ def _init(self): unsqueeze = UnsqueezeTransform( in_keys=in_keys, out_keys=in_keys, - unsqueeze_dim=-4, + dim=-4, ) transforms.append(unsqueeze) diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 9701e96ef62..f1724326d2a 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -32,13 +32,8 @@ from tensordict.base import _is_leaf_nontensor from tensordict.nn import TensorDictModule, TensorDictModuleBase from tensordict.nn.probabilistic import ( # noqa - # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! - # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. - # See more details: https://github.com/pytorch/rl/issues/1016 - interaction_mode as exploration_mode, interaction_type as exploration_type, InteractionType as ExplorationType, - set_interaction_mode as set_exploration_mode, set_interaction_type as set_exploration_type, ) from tensordict.utils import is_non_tensor, NestedKey @@ -55,9 +50,7 @@ from torchrl.data.utils import check_no_exclusive_keys, CloudpickleWrapper __all__ = [ - "exploration_mode", "exploration_type", - "set_exploration_mode", "set_exploration_type", "ExplorationType", "check_env_specs", @@ -79,12 +72,6 @@ ) -def _convert_exploration_type(*, exploration_mode, exploration_type): - if exploration_mode is not None: - return ExplorationType.from_str(exploration_mode) - return exploration_type - - class _classproperty(property): def __get__(self, cls, owner): return classmethod(self.fget).__get__(None, owner)() diff --git a/torchrl/modules/distributions/continuous.py b/torchrl/modules/distributions/continuous.py index 33dfe6aa1df..debb836d6fa 100644 --- a/torchrl/modules/distributions/continuous.py +++ b/torchrl/modules/distributions/continuous.py @@ -212,13 +212,6 @@ class TruncatedNormal(D.Independent): "scale": constraints.greater_than(1e-6), } - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - DeprecationWarning, - ) - def __init__( self, loc: torch.Tensor, @@ -227,14 +220,7 @@ def __init__( low: Union[torch.Tensor, float] = -1.0, high: Union[torch.Tensor, float] = 1.0, tanh_loc: bool = False, - **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") err_msg = "TanhNormal high values must be strictly greater than low values" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): @@ -392,13 +378,6 @@ class TanhNormal(FasterTransformedDistribution): num_params = 2 - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - DeprecationWarning, - ) - def __init__( self, loc: torch.Tensor, @@ -411,13 +390,6 @@ def __init__( safe_tanh: bool = True, **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") - if not isinstance(loc, torch.Tensor): loc = torch.as_tensor(loc, dtype=torch.get_default_dtype()) if not isinstance(scale, torch.Tensor): @@ -530,15 +502,10 @@ def root_dist(self): @property def mode(self): - warnings.warn( - "This computation of the mode is based on an inaccurate estimation of the mode " - "given the base_dist mode. " - "To use a more stable implementation of the mode, use dist.get_mode() method instead. " - "To silence this warning, consider using the DETERMINISTIC exploration_type." - "This implementation will be removed in v0.6.", - category=DeprecationWarning, + raise RuntimeError( + f"The distribution {type(self).__name__} has not analytical mode. " + f"Use ExplorationMode.DETERMINISTIC to get a deterministic sample from it." ) - return self.deterministic_sample @property def deterministic_sample(self): @@ -702,13 +669,6 @@ class TanhDelta(FasterTransformedDistribution): "loc": constraints.real, } - def _warn_minmax(self): - warnings.warn( - f"the min / high keyword arguments are deprecated in favor of low / high in {type(self).__name__} " - f"and will be removed entirely in v0.6. ", - category=DeprecationWarning, - ) - def __init__( self, param: torch.Tensor, @@ -717,15 +677,7 @@ def __init__( event_dims: int = 1, atol: float = 1e-6, rtol: float = 1e-6, - **kwargs, ): - if "max" in kwargs: - self._warn_minmax() - high = kwargs.pop("max") - if "min" in kwargs: - self._warn_minmax() - low = kwargs.pop("min") - minmax_msg = "high value has been found to be equal or less than low value" if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor): if not (high > low).all(): @@ -767,7 +719,6 @@ def __init__( rtol=rtol, batch_shape=batch_shape, event_shape=event_shape, - **kwargs, ) super().__init__(base, t) diff --git a/torchrl/modules/models/exploration.py b/torchrl/modules/models/exploration.py index 720934a6809..d69a85fd685 100644 --- a/torchrl/modules/models/exploration.py +++ b/torchrl/modules/models/exploration.py @@ -553,7 +553,7 @@ class ConsistentDropout(_DropoutNd): .. note:: Unlike other exploration modules, :class:`~torchrl.modules.ConsistentDropoutModule` uses the ``train``/``eval`` mode to comply with the regular `Dropout` API in PyTorch. - The :func:`~torchrl.envs.utils.set_exploration_mode` context manager will have no effect on + The :func:`~torchrl.envs.utils.set_exploration_type` context manager will have no effect on this module. Args: diff --git a/torchrl/modules/tensordict_module/probabilistic.py b/torchrl/modules/tensordict_module/probabilistic.py index 4b38b19c699..483d9b90eea 100644 --- a/torchrl/modules/tensordict_module/probabilistic.py +++ b/torchrl/modules/tensordict_module/probabilistic.py @@ -104,7 +104,6 @@ def __init__( out_keys: Optional[Union[NestedKey, List[NestedKey]]] = None, spec: Optional[TensorSpec] = None, safe: bool = False, - default_interaction_mode: str = None, default_interaction_type: str = InteractionType.DETERMINISTIC, distribution_class: Type = Delta, distribution_kwargs: Optional[dict] = None, @@ -117,7 +116,6 @@ def __init__( in_keys=in_keys, out_keys=out_keys, default_interaction_type=default_interaction_type, - default_interaction_mode=default_interaction_mode, distribution_class=distribution_class, distribution_kwargs=distribution_kwargs, return_log_prob=return_log_prob, diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index cd4e47ef336..a1c70612484 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -97,8 +97,8 @@ class LossModule(TensorDictModuleBase, metaclass=_LossMeta): >>> loss.set_keys(action="action2") .. note:: When a policy that is wrapped or augmented with an exploration module is passed - to the loss, we want to deactivate the exploration through ``set_exploration_mode()`` where - ```` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or + to the loss, we want to deactivate the exploration through ``set_exploration_type()`` where + ```` is either ``ExplorationType.MEAN``, ``ExplorationType.MODE`` or ``ExplorationType.DETERMINISTIC``. The default value is ``DETERMINISTIC`` and it is set through the ``deterministic_sampling_mode`` loss attribute. If another exploration mode is required (or if ``DETERMINISTIC`` is not available), one can diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index b7db2e8242e..68a4b1604cb 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -412,18 +412,13 @@ def value_estimate( @property def is_functional(self): - if isinstance(self.value_network, nn.Module): - return is_functional(self.value_network) - elif self.value_network is None: - return None - else: - raise RuntimeError("Cannot determine if value network is functional.") + # legacy + return False @property def is_stateless(self): - if not self.is_functional: - return False - return self.value_network._is_stateless + # legacy + return False def _next_value(self, tensordict, target_params, kwargs): step_td = step_mdp(tensordict, keep_other=False) diff --git a/torchrl/trainers/helpers/collectors.py b/torchrl/trainers/helpers/collectors.py index b192d115a54..efdde1a1c63 100644 --- a/torchrl/trainers/helpers/collectors.py +++ b/torchrl/trainers/helpers/collectors.py @@ -19,7 +19,6 @@ from torchrl.data.postprocs import MultiStep from torchrl.envs.batched_envs import ParallelEnv from torchrl.envs.common import EnvBase -from torchrl.envs.utils import ExplorationType def sync_async_collector( @@ -304,7 +303,7 @@ def make_collector_offpolicy( "init_random_frames": cfg.init_random_frames, "split_trajs": True, # trajectories must be separated if multi-step is used - "exploration_type": ExplorationType.from_str(cfg.exploration_mode), + "exploration_type": cfg.exploration_type, } collector = collector_helper(**collector_helper_kwargs) @@ -358,7 +357,7 @@ def make_collector_onpolicy( "storing_device": cfg.collector_device, "split_trajs": True, # trajectories must be separated in online settings - "exploration_mode": cfg.exploration_mode, + "exploration_type": cfg.exploration_type, } collector = collector_helper(**collector_helper_kwargs) @@ -398,7 +397,7 @@ class OnPolicyCollectorConfig: # for each of these parallel wrappers. If env_per_collector=num_workers, no parallel wrapper is created seed: int = 42 # seed used for the environment, pytorch and numpy. - exploration_mode: str = "random" + exploration_type: str = "random" # exploration mode of the data collector. async_collection: bool = False # whether data collection should be done asynchrously. Asynchrounous data collection means diff --git a/tutorials/sphinx-tutorials/pendulum.py b/tutorials/sphinx-tutorials/pendulum.py index 94bd8427e30..1593d42a0ec 100644 --- a/tutorials/sphinx-tutorials/pendulum.py +++ b/tutorials/sphinx-tutorials/pendulum.py @@ -609,7 +609,7 @@ def __init__(self, td_params=None, seed=None, device="cpu"): env, # ``Unsqueeze`` the observations that we will concatenate UnsqueezeTransform( - unsqueeze_dim=-1, + dim=-1, in_keys=["th", "thdot"], in_keys_inv=["th", "thdot"], ),