From 38566f60663a2793903b7dd7a3806891b1ef3fde Mon Sep 17 00:00:00 2001 From: Antoine Broyelle Date: Tue, 8 Oct 2024 18:04:51 +0200 Subject: [PATCH] [Feature] Ensure transformation keys have the same number of elements (#2466) --- torchrl/envs/transforms/transforms.py | 89 ++++++++++++++------------- 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index a95a14d42ad..216def16c42 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -42,6 +42,7 @@ from tensordict.nn import dispatch, TensorDictModuleBase from tensordict.utils import ( _unravel_key_to_tuple, + _zip_strict, expand_as_right, expand_right, NestedKey, @@ -88,7 +89,7 @@ def new_fun(self, observation_spec): _specs = observation_spec._specs in_keys = self.in_keys out_keys = self.out_keys - for in_key, out_key in zip(in_keys, out_keys): + for in_key, out_key in _zip_strict(in_keys, out_keys): if in_key in observation_spec.keys(True, True): _specs[out_key] = function(self, observation_spec[in_key].clone()) return Composite( @@ -118,7 +119,7 @@ def new_fun(self, input_spec): state_spec = state_spec.clone() in_keys_inv = self.in_keys_inv out_keys_inv = self.out_keys_inv - for in_key, out_key in zip(in_keys_inv, out_keys_inv): + for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv): if in_key != out_key: # we only change the input spec if the key is the same continue @@ -274,7 +275,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: :meth:`TransformedEnv.reset`. """ - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): value = tensordict.get(in_key, default=None) if value is not None: observation = self._apply_transform(value) @@ -291,7 +292,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: @dispatch(source="in_keys", dest="out_keys") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: """Reads the input tensordict, and for the selected keys, applies the transform.""" - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): data = tensordict.get(in_key, None) if data is not None: data = self._apply_transform(data) @@ -332,7 +333,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor: def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: if not self.in_keys_inv: return tensordict - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): data = tensordict.get(in_key, None) if data is not None: item = self._inv_apply_transform(data) @@ -1637,7 +1638,7 @@ def _reset(self, tensordict: TensorDict, tensordict_reset: TensorDictBase): return tensordict_reset def _call(self, tensordict: TensorDict) -> TensorDict: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): val_in = tensordict.get(in_key, None) val_out = tensordict.get(out_key, None) if val_in is not None: @@ -1679,7 +1680,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: ) def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if in_key in self.parent.full_observation_spec.keys(True): target = self.parent.full_observation_spec[in_key] elif in_key in self.parent.full_reward_spec.keys(True): @@ -3004,7 +3005,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor: def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase: """Update the episode tensordict with max pooled keys.""" _just_reset = _reset is not None - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): # Lazy init of buffers buffer_name = f"_cat_buffers_{in_key}" data = tensordict.get(in_key) @@ -3139,12 +3140,12 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase: # first sort the in_keys with strings and non-strings keys = [ (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys) if isinstance(in_key, str) ] keys += [ (in_key, out_key) - for in_key, out_key in zip(self.in_keys, self.out_keys) + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys) if not isinstance(in_key, str) ] @@ -3180,7 +3181,7 @@ def unfold_done(done, N): first_val = None if isinstance(in_key, tuple) and in_key[0] == "next": # let's get the out_key we have already processed - prev_out_key = dict(zip(self.in_keys, self.out_keys)).get( + prev_out_key = dict(_zip_strict(self.in_keys, self.out_keys)).get( in_key[1], None ) if prev_out_key is not None: @@ -3613,7 +3614,7 @@ def func(name, item): return tensordict else: # we made sure that if in_keys is not None, out_keys is not None either - for in_key, out_key in zip(in_keys, out_keys): + for in_key, out_key in _zip_strict(in_keys, out_keys): item = self._apply_transform(tensordict.get(in_key)) tensordict.set(out_key, item) return tensordict @@ -3672,7 +3673,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec: raise NotImplementedError( f"Calling transform_input_spec without a parent environment isn't supported yet for {type(self)}." ) - for in_key_inv, out_key_inv in zip(self.in_keys_inv, self.out_keys_inv): + for in_key_inv, out_key_inv in _zip_strict(self.in_keys_inv, self.out_keys_inv): if in_key_inv in full_action_spec.keys(True): _spec = full_action_spec[in_key_inv] target = "action" @@ -3706,7 +3707,7 @@ def transform_output_spec(self, output_spec: Composite) -> Composite: full_observation_spec = output_spec["full_observation_spec"] for reward_key, reward_spec in list(full_reward_spec.items(True, True)): # find out_key that match the in_key - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if reward_key == in_key: if reward_spec.dtype != self.dtype_in: raise TypeError(f"reward_spec.dtype is not {self.dtype_in}") @@ -3722,7 +3723,7 @@ def transform_observation_spec(self, observation_spec): full_observation_spec.items(True, True) ): # find out_key that match the in_key - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if observation_key == in_key: if observation_spec.dtype != self.dtype_in: raise TypeError( @@ -3955,7 +3956,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: return result tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None) if self._rename_keys: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if out_key != in_key: tensordict_t.rename_key_(in_key, out_key) tensordict_t.set(in_key, tensordict.get(in_key)) @@ -3969,7 +3970,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: return result tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None) if self._rename_keys: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if out_key != in_key: tensordict_t.rename_key_(in_key, out_key) tensordict_t.set(in_key, tensordict.get(in_key)) @@ -3997,7 +3998,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: device=None, ) if self._rename_keys_inv: - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): if out_key != in_key: tensordict_t.rename_key_(in_key, out_key) tensordict_t.set(in_key, tensordict.get(in_key)) @@ -4030,7 +4031,7 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: def transform_action_spec(self, full_action_spec: Composite) -> Composite: full_action_spec = full_action_spec.clear_device_() - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): if in_key not in full_action_spec.keys(True, True): continue full_action_spec[out_key] = full_action_spec[in_key].to(self.device) @@ -4038,7 +4039,7 @@ def transform_action_spec(self, full_action_spec: Composite) -> Composite: def transform_state_spec(self, full_state_spec: Composite) -> Composite: full_state_spec = full_state_spec.clear_device_() - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): if in_key not in full_state_spec.keys(True, True): continue full_state_spec[out_key] = full_state_spec[in_key].to(self.device) @@ -4052,7 +4053,7 @@ def transform_output_spec(self, output_spec: Composite) -> Composite: def transform_observation_spec(self, observation_spec: Composite) -> Composite: observation_spec = observation_spec.clear_device_() - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if in_key not in observation_spec.keys(True, True): continue observation_spec[out_key] = observation_spec[in_key].to(self.device) @@ -4060,7 +4061,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite: def transform_done_spec(self, full_done_spec: Composite) -> Composite: full_done_spec = full_done_spec.clear_device_() - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if in_key not in full_done_spec.keys(True, True): continue full_done_spec[out_key] = full_done_spec[in_key].to(self.device) @@ -4068,7 +4069,7 @@ def transform_done_spec(self, full_done_spec: Composite) -> Composite: def transform_reward_spec(self, full_reward_spec: Composite) -> Composite: full_reward_spec = full_reward_spec.clear_device_() - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if in_key not in full_reward_spec.keys(True, True): continue full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device) @@ -5023,7 +5024,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.lock is not None: self.lock.acquire() - for key, key_out in zip(self.in_keys, self.out_keys): + for key, key_out in _zip_strict(self.in_keys, self.out_keys): if key not in tensordict.keys(include_nested=True): # TODO: init missing rewards with this # for key_suffix in [_append_last(key, suffix) for suffix in ("_sum", "_ssq", "_count")]: @@ -5161,7 +5162,7 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]: out = [] loc = self.loc scale = self.scale - for key, key_out in zip(self.in_keys, self.out_keys): + for key, key_out in _zip_strict(self.in_keys, self.out_keys): _out = ObservationNorm( loc=loc.get(key), scale=scale.get(key), @@ -5480,7 +5481,7 @@ def reset_keys(self): def _check_match(reset_keys, in_keys): # if this is called, the length of reset_keys and in_keys must match - for reset_key, in_key in zip(reset_keys, in_keys): + for reset_key, in_key in _zip_strict(reset_keys, in_keys): # having _reset at the root and the reward_key ("agent", "reward") is allowed # but having ("agent", "_reset") and "reward" isn't if isinstance(reset_key, tuple) and isinstance(in_key, str): @@ -5524,7 +5525,7 @@ def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: """Resets episode rewards.""" - for in_key, reset_key, out_key in zip( + for in_key, reset_key, out_key in _zip_strict( self.in_keys, self.reset_keys, self.out_keys ): _reset = _get_reset(reset_key, tensordict) @@ -5541,7 +5542,7 @@ def _step( ) -> TensorDictBase: """Updates the episode rewards with the step rewards.""" # Update episode rewards - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if in_key in next_tensordict.keys(include_nested=True): reward = next_tensordict.get(in_key) prev_reward = tensordict.get(out_key, 0.0) @@ -5563,7 +5564,7 @@ def _generate_episode_reward_spec(self) -> Composite: reward_spec = self.parent.full_reward_spec reward_spec_keys = self.parent.reward_keys # Define episode specs for all out_keys - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): if ( in_key in reward_spec_keys ): # if this out_key has a corresponding key in reward_spec @@ -5613,7 +5614,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "At least one dimension of the tensordict must be named 'time' in offline mode" ) time_dim = time_dim[0] - 1 - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): reward = tensordict[in_key] cumsum = reward.cumsum(time_dim) tensordict.set(out_key, cumsum) @@ -5791,7 +5792,13 @@ def _reset( self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase ) -> TensorDictBase: # get reset signal - for step_count_key, truncated_key, terminated_key, reset_key, done_key in zip( + for ( + step_count_key, + truncated_key, + terminated_key, + reset_key, + done_key, + ) in _zip_strict( self.step_count_keys, self.truncated_keys, self.terminated_keys, @@ -5832,10 +5839,8 @@ def _reset( def _step( self, tensordict: TensorDictBase, next_tensordict: TensorDictBase ) -> TensorDictBase: - for step_count_key, truncated_key, done_key in zip( - self.step_count_keys, - self.truncated_keys, - self.done_keys, + for step_count_key, truncated_key, done_key in _zip_strict( + self.step_count_keys, self.truncated_keys, self.done_keys ): step_count = tensordict.get(step_count_key) next_step_count = step_count + 1 @@ -6334,7 +6339,7 @@ def _make_missing_buffer(self, tensordict, in_key, buffer_name): def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase: """Update the episode tensordict with max pooled keys.""" - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): # Lazy init of buffers buffer_name = self._buffer_name(in_key) buffer = getattr(self, buffer_name) @@ -6575,7 +6580,7 @@ def _reset( device = tensordict.device if device is None: device = torch.device("cpu") - for reset_key, init_key in zip(self.reset_keys, self.init_keys): + for reset_key, init_key in _zip_strict(self.reset_keys, self.init_keys): _reset = tensordict.get(reset_key, None) if _reset is None: done_key = _replace_last(init_key, "done") @@ -6711,7 +6716,7 @@ def __init__( def _call(self, tensordict: TensorDictBase) -> TensorDictBase: if self.create_copy: out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance) - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): try: out.rename_key_(in_key, out_key) except KeyError: @@ -6719,7 +6724,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: raise tensordict = tensordict.update(out) else: - for in_key, out_key in zip(self.in_keys, self.out_keys): + for in_key, out_key in _zip_strict(self.in_keys, self.out_keys): try: tensordict.rename_key_(in_key, out_key) except KeyError: @@ -6741,7 +6746,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: out = tensordict.select( *self.out_keys_inv, strict=not self._missing_tolerance ) - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): try: out.rename_key_(out_key, in_key) except KeyError: @@ -6750,7 +6755,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.update(out) else: - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): try: tensordict.rename_key_(out_key, in_key) except KeyError: @@ -6971,7 +6976,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase: "No episode ends found to calculate the reward to go. Make sure that the number of frames_per_batch is larger than number of steps per episode." ) found = False - for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv): + for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv): if in_key in tensordict.keys(include_nested=True): found = True item = self._inv_apply_transform(tensordict.get(in_key), done)