From 152bc81b7c8bb57cf00a3f54769e262c1ec874df Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:00:40 +0100 Subject: [PATCH 01/10] [BugFix,Test,Benchmark] Fix graph breaks induced by device context manager ghstack-source-id: 0df2728928280a43de4abd30afed20826b0af091 Pull Request resolved: https://github.com/pytorch/rl/pull/2602 --- benchmarks/test_objectives_benchmarks.py | 89 ++++++++++++++++++++---- 1 file changed, 77 insertions(+), 12 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index 9932c8ba8b7..629d83a6dd3 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -50,13 +50,13 @@ ) # Anything from 2.5, incl. nightlies, allows for fullgraph -@pytest.fixture(scope="module", autouse=True) -def set_default_device(): - cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - torch.set_default_device(device) - yield - torch.set_default_device(cur_device) +# @pytest.fixture(scope="module", autouse=True) +# def set_default_device(): +# cur_device = getattr(torch, "get_default_device", lambda: torch.device("cpu"))() +# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# torch.set_default_device(device) +# yield +# torch.set_default_device(cur_device) class setup_value_fn: @@ -173,7 +173,14 @@ def test_dqn_speed( ): if compile: torch._dynamo.reset_code_caches() - net = MLP(in_features=n_obs, out_features=n_act, depth=depth, num_cells=ncells) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + net = MLP( + in_features=n_obs, + out_features=n_act, + depth=depth, + num_cells=ncells, + device=device, + ) action_space = "one-hot" mod = QValueActor(net, in_keys=["obs"], action_space=action_space) loss = DQNLoss(value_network=mod, action_space=action_space) @@ -188,6 +195,7 @@ def test_dqn_speed( }, }, [batch], + device=device, ) loss(td) @@ -220,23 +228,27 @@ def test_ddpg_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -251,6 +263,7 @@ def test_ddpg_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor_head = Mod(actor, in_keys=["hidden"], out_keys=["action"]) @@ -291,23 +304,27 @@ def test_sac_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -322,6 +339,7 @@ def test_sac_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -374,23 +392,27 @@ def test_redq_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -405,6 +427,7 @@ def test_redq_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -460,23 +483,27 @@ def test_redq_deprec_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -491,6 +518,7 @@ def test_redq_deprec_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -544,23 +572,27 @@ def test_td3_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -575,6 +607,7 @@ def test_td3_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -633,23 +666,27 @@ def test_cql_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch] td = TensorDict( @@ -664,6 +701,7 @@ def test_cql_speed( }, }, batch, + device=device, ) common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -724,23 +762,27 @@ def test_a2c_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -757,6 +799,7 @@ def test_a2c_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -775,7 +818,9 @@ def test_a2c_speed( critic(td.clone()) loss = A2CLoss(actor_network=actor, critic_network=critic) - advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) + advantage = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device + ) advantage(td) loss(td) @@ -816,23 +861,27 @@ def test_ppo_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -849,6 +898,7 @@ def test_ppo_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -867,7 +917,9 @@ def test_ppo_speed( critic(td.clone()) loss = ClipPPOLoss(actor_network=actor, critic_network=critic) - advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) + advantage = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device + ) advantage(td) loss(td) @@ -908,23 +960,27 @@ def test_reinforce_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -941,6 +997,7 @@ def test_reinforce_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -959,7 +1016,9 @@ def test_reinforce_speed( critic(td.clone()) loss = ReinforceLoss(actor_network=actor, critic_network=critic) - advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) + advantage = GAE( + value_network=critic, gamma=0.99, lmbda=0.95, shifted=True, device=device + ) advantage(td) loss(td) @@ -1000,29 +1059,34 @@ def test_iql_speed( ): if compile: torch._dynamo.reset_code_caches() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") common_net = MLP( num_cells=ncells, in_features=n_obs, depth=3, out_features=n_hidden, + device=device, ) actor_net = MLP( num_cells=ncells, in_features=n_hidden, depth=2, out_features=2 * n_act, + device=device, ) value_net = MLP( in_features=n_hidden, num_cells=ncells, depth=2, out_features=1, + device=device, ) qvalue_net = MLP( in_features=n_hidden + n_act, num_cells=ncells, depth=2, out_features=1, + device=device, ) batch = [batch, T] td = TensorDict( @@ -1039,6 +1103,7 @@ def test_iql_speed( }, batch, names=[None, "time"], + device=device, ) common = Mod(common_net, in_keys=["obs"], out_keys=["hidden"]) actor = ProbSeq( @@ -1087,4 +1152,4 @@ def loss_and_bw(td): if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) + pytest.main([__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + unknown) From 097d8ad9879f8e83c89c7d8da0a4727146355cfd Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:13:36 +0100 Subject: [PATCH 02/10] [Feature] spec.is_empty(recurse) ghstack-source-id: faa3b1df5133c77462d6dd013d3854d684cc7e94 Pull Request resolved: https://github.com/pytorch/rl/pull/2596 --- torchrl/data/tensor_specs.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 2ef74bb4521..1f31db01ec7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4278,8 +4278,22 @@ def shape(self, value: torch.Size): ) self._shape = _size(value) - def is_empty(self): - """Whether the composite spec contains specs or not.""" + def is_empty(self, recurse: bool = False): + """Whether the composite spec contains specs or not. + + Args: + recurse (bool): whether to recursively assess if the spec is empty. + If ``True``, will return ``True`` if there are no leaves. If ``False`` + (default) will return whether there is any spec defined at the root level. + + """ + if recurse: + for spec in self._specs.values(): + if spec is None: + continue + if isinstance(spec, Composite) and spec.is_empty(recurse=True): + continue + return False return len(self._specs) == 0 @property From 2e82cab191132d6272d4b598a539bf88c5e961a0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:13:37 +0100 Subject: [PATCH 03/10] [Feature] Composite.batch_size ghstack-source-id: 621884a559a71e80a4be36c7ba984fd08be47952 Pull Request resolved: https://github.com/pytorch/rl/pull/2597 --- torchrl/data/tensor_specs.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 1f31db01ec7..5404beb0ec0 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4257,6 +4257,14 @@ def __new__(cls, *args, **kwargs): cls._locked = False return super().__new__(cls) + @property + def batch_size(self): + return self._shape + + @batch_size.setter + def batch_size(self, value: torch.Size): + self._shape = value + @property def shape(self): return self._shape From 8d16c12bd783c4e36dc24dca56c7cc24f115d37c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:13:38 +0100 Subject: [PATCH 04/10] [Feature] Composite.pop ghstack-source-id: 64d5bd736657ef56e37d57726dfcfd25b16b699f Pull Request resolved: https://github.com/pytorch/rl/pull/2598 --- torchrl/data/tensor_specs.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 5404beb0ec0..b701b2f6bf7 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4311,6 +4311,34 @@ def ndim(self): def ndimension(self): return len(self.shape) + def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any: + """Removes and returns the value associated with the specified key from the composite spec. + + This method searches for the given key in the composite spec, removes it, and returns its associated value. + If the key is not found, it returns the provided default value if specified, otherwise raises a `KeyError`. + + Args: + key (NestedKey): + The key to be removed from the composite spec. It can be a single key or a nested key. + default (Any, optional): + The value to return if the specified key is not found in the composite spec. + If not provided and the key is not found, a `KeyError` is raised. + + Returns: + Any: The value associated with the specified key that was removed from the composite spec. + + Raises: + KeyError: If the specified key is not found in the composite spec and no default value is provided. + """ + key = unravel_key(key) + if key in self.keys(True, True): + result = self[key] + del self[key] + return result + elif default is not NO_DEFAULT: + return default + raise KeyError(f"{key} not found in composite spec.") + def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked Composite.") From 83e0b0568d95aba0a462da383db7639e51f35be7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:13:40 +0100 Subject: [PATCH 05/10] [Feature] Composite.separates ghstack-source-id: fbfc4308a81cd96ecc61723df8c0eb870c442def Pull Request resolved: https://github.com/pytorch/rl/pull/2599 --- torchrl/data/tensor_specs.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index b701b2f6bf7..32e61bc3ede 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -4339,6 +4339,33 @@ def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> Any: return default raise KeyError(f"{key} not found in composite spec.") + def separates(self, *keys: NestedKey, default: Any = None) -> Composite: + """Splits the composite spec by extracting specified keys and their associated values into a new composite spec. + + This method iterates over the provided keys, removes them from the current composite spec, and adds them to a new + composite spec. If a key is not found, the specified default value is used. The new composite spec is returned. + + Args: + *keys (NestedKey): + One or more keys to be extracted from the composite spec. Each key can be a single key or a nested key. + default (Any, optional): + The value to use if a specified key is not found in the composite spec. Defaults to `None`. + + Returns: + Composite: A new composite spec containing the extracted keys and their associated values. + + Note: + If none of the specified keys are found, the method returns `None`. + """ + out = None + for key in keys: + result = self.pop(key, default=default) + if result is not None: + if out is None: + out = Composite(batch_size=self.batch_size, device=self.device) + out[key] = result + return out + def set(self, name, spec): if self.locked: raise RuntimeError("Cannot modify a locked Composite.") From de153bf454df9fab7c473dea687802ce42c679ea Mon Sep 17 00:00:00 2001 From: carschandler <92899389+carschandler@users.noreply.github.com> Date: Sun, 24 Nov 2024 02:19:24 -0600 Subject: [PATCH 06/10] [Doc] fix several typos (#2603) --- torchrl/modules/models/models.py | 2 +- tutorials/sphinx-tutorials/getting-started-3.py | 4 ++-- tutorials/sphinx-tutorials/rb_tutorial.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/modules/models/models.py b/torchrl/modules/models/models.py index 49c0b0961ab..cad4065f54a 100644 --- a/torchrl/modules/models/models.py +++ b/torchrl/modules/models/models.py @@ -955,7 +955,7 @@ class DuelingCnnDQNet(nn.Module): >>> cnn_kwargs = { ... 'num_cells': [32, 64, 64], ... 'strides': [4, 2, 1], - ... 'kernels': [8, 4, 3], + ... 'kernel_sizes': [8, 4, 3], ... } mlp_kwargs (dict or list of dicts, optional): kwargs for the advantage diff --git a/tutorials/sphinx-tutorials/getting-started-3.py b/tutorials/sphinx-tutorials/getting-started-3.py index 594cb7392c0..70ffe37a005 100644 --- a/tutorials/sphinx-tutorials/getting-started-3.py +++ b/tutorials/sphinx-tutorials/getting-started-3.py @@ -152,7 +152,7 @@ indices = buffer.extend(data) ################################# -# We can check that the buffer now has the same number of elements than what +# We can check that the buffer now has the same number of elements as what # we got from the collector: assert len(buffer) == collector.frames_per_batch @@ -174,7 +174,7 @@ # Next steps # ---------- # -# - You can have look at other multirpocessed +# - You can have look at other multiprocessed # collectors such as :class:`~torchrl.collectors.collectors.MultiSyncDataCollector` or # :class:`~torchrl.collectors.collectors.MultiaSyncDataCollector`. # - TorchRL also offers distributed collectors if you have multiple nodes to diff --git a/tutorials/sphinx-tutorials/rb_tutorial.py b/tutorials/sphinx-tutorials/rb_tutorial.py index 4f5ecb4936d..2a852f0e364 100644 --- a/tutorials/sphinx-tutorials/rb_tutorial.py +++ b/tutorials/sphinx-tutorials/rb_tutorial.py @@ -168,7 +168,7 @@ buffer_lazytensor = ReplayBuffer(storage=LazyTensorStorage(size)) ###################################################################### -# Let us create a batch of data of size ``torch.Size([3])` with 2 tensors +# Let us create a batch of data of size ``torch.Size([3])`` with 2 tensors # stored in it: # From 00d3199ecc50d004ea2219b51b64066e6e485014 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 24 Nov 2024 09:13:41 +0100 Subject: [PATCH 07/10] [Feature] EnvBase.check_env_specs ghstack-source-id: 332dbf92db496c71c5ce6aba340ad123eac0f5d6 Pull Request resolved: https://github.com/pytorch/rl/pull/2600 --- torchrl/envs/common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index a7c004bfcc5..0611af20b45 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -390,6 +390,11 @@ def __init__( self.batch_size = torch.Size(batch_size) self._run_type_checks = run_type_checks self._allow_done_after_reset = allow_done_after_reset + @wraps(check_env_specs_func) + def check_env_specs(self, *args, **kwargs): + return check_env_specs_func(self, *args, **kwargs) + + check_env_specs.__doc__ = check_env_specs_func.__doc__ @classmethod def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs): From a1e21f598fa587ac848f2cc1bcb44d36c32d06ba Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Sun, 24 Nov 2024 18:19:49 +0200 Subject: [PATCH 08/10] [BugFix] Wrong spec returned (#2604) --- torchrl/envs/common.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 0611af20b45..949f5e3b621 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -390,6 +390,7 @@ def __init__( self.batch_size = torch.Size(batch_size) self._run_type_checks = run_type_checks self._allow_done_after_reset = allow_done_after_reset + @wraps(check_env_specs_func) def check_env_specs(self, *args, **kwargs): return check_env_specs_func(self, *args, **kwargs) @@ -1550,7 +1551,7 @@ def action_spec_unbatched(self, spec: Composite): @property def full_observation_spec_unbatched(self) -> Composite: """Returns the observation spec of the env as if it had no batch dimensions.""" - return self._make_single_env_spec(self.full_action_spec) + return self._make_single_env_spec(self.full_observation_spec) @full_observation_spec_unbatched.setter def full_observation_spec_unbatched(self, spec: Composite): @@ -1570,7 +1571,7 @@ def observation_spec_unbatched(self, spec: Composite): @property def full_reward_spec_unbatched(self) -> Composite: """Returns the reward spec of the env as if it had no batch dimensions.""" - return self._make_single_env_spec(self.full_action_spec) + return self._make_single_env_spec(self.full_reward_spec) @full_reward_spec_unbatched.setter def full_reward_spec_unbatched(self, spec: Composite): @@ -1590,7 +1591,7 @@ def reward_spec_unbatched(self, spec: Composite): @property def full_done_spec_unbatched(self) -> Composite: """Returns the done spec of the env as if it had no batch dimensions.""" - return self._make_single_env_spec(self.full_action_spec) + return self._make_single_env_spec(self.full_done_spec) @full_done_spec_unbatched.setter def full_done_spec_unbatched(self, spec: Composite): From d90b9e3d1a1a92e43aafa0b7693e45766795ea46 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 11:34:30 +0000 Subject: [PATCH 09/10] [BugFix] Fix imports ghstack-source-id: db85f2611c1c0b22e9179b4fdd6c2dcea78ac8dd Pull Request resolved: https://github.com/pytorch/rl/pull/2605 --- benchmarks/test_objectives_benchmarks.py | 5 ++++- setup.py | 13 ++++++++++++- test/test_rlhf.py | 16 ++++++---------- torchrl/envs/common.py | 19 ++++++++++--------- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index 629d83a6dd3..2e8fe407b3d 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -1152,4 +1152,7 @@ def loss_and_bw(td): if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() - pytest.main([__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + unknown) + pytest.main( + [__file__, "--capture", "no", "--exitfirst", "--benchmark-group-by", "func"] + + unknown + ) diff --git a/setup.py b/setup.py index 823ec307052..33d9d5c7268 100644 --- a/setup.py +++ b/setup.py @@ -209,7 +209,18 @@ def _main(argv): "dm_control": ["dm_control"], "gym_continuous": ["gymnasium<1.0", "mujoco"], "rendering": ["moviepy<2.0.0"], - "tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"], + "tests": [ + "pytest", + "pyyaml", + "pytest-instafail", + "scipy", + "pytest-mock", + "pytest-cov", + "pytest-benchmark", + "pytest-rerunfailures", + "pytest-error-for-skips", + "", + ], "utils": [ "tensorboard", "wandb", diff --git a/test/test_rlhf.py b/test/test_rlhf.py index 486ddbef127..1c09856dd36 100644 --- a/test/test_rlhf.py +++ b/test/test_rlhf.py @@ -298,12 +298,10 @@ def test_tensordict_tokenizer( "Lettuce in, it's cold out here!", ] } - if not truncation and return_tensordict and max_length == 10: - with pytest.raises(ValueError, match="TensorDict conversion only supports"): - out = process(example) - return out = process(example) - if return_tensordict: + if not truncation and return_tensordict and max_length == 10: + assert out.get("input_ids").shape[-1] == -1 + elif return_tensordict: assert out.get("input_ids").shape[-1] == max_length else: obj = out.get("input_ids") @@ -346,12 +344,10 @@ def test_prompt_tensordict_tokenizer( ], "label": ["right", "wrong", "right", "wrong", "right"], } - if not truncation and return_tensordict and max_length == 10: - with pytest.raises(ValueError, match="TensorDict conversion only supports"): - out = process(example) - return out = process(example) - if return_tensordict: + if not truncation and return_tensordict and max_length == 10: + assert out.get("input_ids").shape[-1] == -1 + elif return_tensordict: assert out.get("input_ids").shape[-1] == max_length else: obj = out.get("input_ids") diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index 949f5e3b621..8adf36b0019 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -6,9 +6,9 @@ from __future__ import annotations import abc -import functools import warnings from copy import deepcopy +from functools import partial, wraps from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple import numpy as np @@ -33,6 +33,7 @@ _StepMDP, _terminated_or_truncated, _update_during_reset, + check_env_specs as check_env_specs_func, get_available_libraries, ) @@ -2035,7 +2036,7 @@ def _register_gym( if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2084,7 +2085,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2138,7 +2139,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2195,7 +2196,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2254,7 +2255,7 @@ def _register_gym( # noqa: F811 ) if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymWrapper, entry_point=entry_point, info_keys=info_keys, @@ -2293,7 +2294,7 @@ def _register_gym( # noqa: F811 if entry_point is None: entry_point = cls - entry_point = functools.partial( + entry_point = partial( _TorchRLGymnasiumWrapper, entry_point=entry_point, info_keys=info_keys, @@ -3422,11 +3423,11 @@ def _get_sync_func(policy_device, env_device): if policy_device is not None and policy_device.type == "cuda": if env_device is None or env_device.type == "cuda": return torch.cuda.synchronize - return functools.partial(torch.cuda.synchronize, device=policy_device) + return partial(torch.cuda.synchronize, device=policy_device) if env_device is not None and env_device.type == "cuda": if policy_device is None: return torch.cuda.synchronize - return functools.partial(torch.cuda.synchronize, device=env_device) + return partial(torch.cuda.synchronize, device=env_device) return torch.cuda.synchronize if torch.backends.mps.is_available(): return torch.mps.synchronize From c8676f4a87df65bff0ccc42ea09942ef73ce4d9a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 13:34:30 +0000 Subject: [PATCH 10/10] [BugFix] Account for terminating data in SAC losses ghstack-source-id: dc1870292786c262b4ab6a221b3afb551e0efb9b Pull Request resolved: https://github.com/pytorch/rl/pull/2606 --- test/test_cost.py | 119 ++++++++++++++++++++++++++++++++++++++ torchrl/objectives/sac.py | 51 +++++++++++++--- 2 files changed, 162 insertions(+), 8 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 598b9ba004d..c48b4a28b99 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -4459,6 +4459,69 @@ def test_sac_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_alpha == loss_val_td["loss_alpha"] + @pytest.mark.parametrize("action_key", ["action", "action2"]) + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_sac_terminating( + self, action_key, observation_key, reward_key, done_key, terminated_key, version + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_sac( + action_key=action_key, + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + + actor = self._create_mock_actor( + observation_key=observation_key, action_key=action_key + ) + qvalue = self._create_mock_qvalue( + observation_key=observation_key, + action_key=action_key, + out_keys=["state_action_value"], + ) + if version == 1: + value = self._create_mock_value(observation_key=observation_key) + else: + value = None + + loss = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + value_network=value, + ) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) + + torch.manual_seed(self.seed) + + SoftUpdate(loss, eps=0.5) + + done = td.get(("next", done_key)) + while not (done.any() and not done.all()): + done.bernoulli_(0.1) + obs_nan = td.get(("next", terminated_key)) + obs_nan[done.squeeze(-1)] = float("nan") + + kwargs = { + action_key: td.get(action_key), + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": done, + f"next_{terminated_key}": obs_nan, + f"next_{observation_key}": td.get(("next", observation_key)), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + assert loss(td).isfinite().all() + def test_state_dict(self, version): if version == 1: pytest.skip("Test not implemented for version 1.") @@ -5112,6 +5175,62 @@ def test_discrete_sac_notensordict( assert loss_actor == loss_val_td["loss_actor"] assert loss_alpha == loss_val_td["loss_alpha"] + @pytest.mark.parametrize("action_key", ["action", "action2"]) + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_discrete_sac_terminating( + self, action_key, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + td = self._create_mock_data_sac( + action_key=action_key, + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + + actor = self._create_mock_actor( + observation_key=observation_key, action_key=action_key + ) + qvalue = self._create_mock_qvalue( + observation_key=observation_key, + ) + + loss = DiscreteSACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_actions=actor.spec[action_key].space.n, + action_space="one-hot", + ) + loss.set_keys( + action=action_key, + reward=reward_key, + done=done_key, + terminated=terminated_key, + ) + + SoftUpdate(loss, eps=0.5) + + torch.manual_seed(0) + done = td.get(("next", done_key)) + while not (done.any() and not done.all()): + done = done.bernoulli_(0.1) + obs_none = td.get(("next", observation_key)) + obs_none[done.squeeze(-1)] = float("nan") + kwargs = { + action_key: td.get(action_key), + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": done, + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": obs_none, + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + assert loss(td).isfinite().all() + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) def test_discrete_sac_reduction(self, reduction): torch.manual_seed(self.seed) diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 3fb34678d02..52efb3d312b 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -16,7 +16,7 @@ from tensordict import TensorDict, TensorDictBase, TensorDictParams from tensordict.nn import dispatch, TensorDictModule -from tensordict.utils import NestedKey +from tensordict.utils import expand_right, NestedKey from torch import Tensor from torchrl.data.tensor_specs import Composite, TensorSpec from torchrl.data.utils import _find_action_space @@ -711,13 +711,37 @@ def _compute_target_v2(self, tensordict) -> Tensor: with set_exploration_type( ExplorationType.RANDOM ), self.actor_network_params.to_module(self.actor_network): - next_tensordict = tensordict.get("next").clone(False) - next_dist = self.actor_network.get_dist(next_tensordict) + next_tensordict = tensordict.get("next").copy() + # Check done state and avoid passing these to the actor + done = next_tensordict.get(self.tensor_keys.done) + if done is not None and done.any(): + next_tensordict_select = next_tensordict[~done.squeeze(-1)] + else: + next_tensordict_select = next_tensordict + next_dist = self.actor_network.get_dist(next_tensordict_select) next_action = next_dist.rsample() - next_tensordict.set(self.tensor_keys.action, next_action) next_sample_log_prob = compute_log_prob( next_dist, next_action, self.tensor_keys.log_prob ) + if next_tensordict_select is not next_tensordict: + mask = ~done.squeeze(-1) + if mask.ndim < next_action.ndim: + mask = expand_right( + mask, (*mask.shape, *next_action.shape[mask.ndim :]) + ) + next_action = next_action.new_zeros(mask.shape).masked_scatter_( + mask, next_action + ) + mask = ~done.squeeze(-1) + if mask.ndim < next_sample_log_prob.ndim: + mask = expand_right( + mask, + (*mask.shape, *next_sample_log_prob.shape[mask.ndim :]), + ) + next_sample_log_prob = next_sample_log_prob.new_zeros( + mask.shape + ).masked_scatter_(mask, next_sample_log_prob) + next_tensordict.set(self.tensor_keys.action, next_action) # get q-values next_tensordict_expand = self._vmap_qnetworkN0( @@ -1194,15 +1218,21 @@ def _compute_target(self, tensordict) -> Tensor: with torch.no_grad(): next_tensordict = tensordict.get("next").clone(False) + done = next_tensordict.get(self.tensor_keys.done) + if done is not None and done.any(): + next_tensordict_select = next_tensordict[~done.squeeze(-1)] + else: + next_tensordict_select = next_tensordict + # get probs and log probs for actions computed from "next" with self.actor_network_params.to_module(self.actor_network): - next_dist = self.actor_network.get_dist(next_tensordict) - next_prob = next_dist.probs - next_log_prob = torch.log(torch.where(next_prob == 0, 1e-8, next_prob)) + next_dist = self.actor_network.get_dist(next_tensordict_select) + next_log_prob = next_dist.logits + next_prob = next_log_prob.exp() # get q-values for all actions next_tensordict_expand = self._vmap_qnetworkN0( - next_tensordict, self.target_qvalue_network_params + next_tensordict_select, self.target_qvalue_network_params ) next_action_value = next_tensordict_expand.get( self.tensor_keys.action_value @@ -1212,6 +1242,11 @@ def _compute_target(self, tensordict) -> Tensor: next_state_value = next_action_value.min(0)[0] - self._alpha * next_log_prob # unlike in continuous SAC, we can compute the exact expectation over all discrete actions next_state_value = (next_prob * next_state_value).sum(-1).unsqueeze(-1) + if next_tensordict_select is not next_tensordict: + mask = ~done.squeeze(-1) + next_state_value = next_state_value.new_zeros( + mask.shape + ).masked_scatter_(mask, next_state_value) tensordict.set( ("next", self.value_estimator.tensor_keys.value), next_state_value