[Feature] Ensure transformation keys have the same number of elements (
antoinebrl authored Oct 8, 2024
1 parent b116151 commit 38566f6
Showing 1 changed file with 47 additions and 42 deletions.
89 changes: 47 additions & 42 deletions torchrl/envs/transforms/transforms.py
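This change swaps Python's built-in zip for tensordict's _zip_strict wherever in_keys are paired with out_keys (and in_keys_inv with out_keys_inv), so that key lists of different lengths raise an error instead of being silently truncated to the shorter list. Below is a minimal sketch of what such a strict zip can look like; the helper name zip_strict and its error message are illustrative assumptions, not the actual tensordict.utils._zip_strict implementation. On Python 3.10+ the same behavior is available via zip(..., strict=True).

# Minimal sketch of a strict zip helper (assumption: the real
# tensordict.utils._zip_strict may differ in name, signature, and message).
def zip_strict(*iterables):
    lists = [list(it) for it in iterables]
    lengths = {len(lst) for lst in lists}
    if len(lengths) > 1:
        # Refuse to pair lists of different lengths instead of truncating.
        raise ValueError(f"Expected iterables of equal length, got lengths {sorted(lengths)}")
    return zip(*lists)

# Example: a transform configured with mismatched key lists now fails loudly
# instead of silently ignoring the unmatched in_key.
in_keys = ["observation", "pixels"]
out_keys = ["observation_norm"]  # one out_key missing by mistake
try:
    list(zip_strict(in_keys, out_keys))
except ValueError as err:
    print(err)  # Expected iterables of equal length, got lengths [1, 2]

With plain zip, the loops below over in_keys and out_keys would simply stop at the shorter list and the misconfiguration would go unnoticed.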
@@ -42,6 +42,7 @@
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import (
_unravel_key_to_tuple,
+_zip_strict,
expand_as_right,
expand_right,
NestedKey,
@@ -88,7 +89,7 @@ def new_fun(self, observation_spec):
_specs = observation_spec._specs
in_keys = self.in_keys
out_keys = self.out_keys
-for in_key, out_key in zip(in_keys, out_keys):
+for in_key, out_key in _zip_strict(in_keys, out_keys):
if in_key in observation_spec.keys(True, True):
_specs[out_key] = function(self, observation_spec[in_key].clone())
return Composite(
@@ -118,7 +119,7 @@ def new_fun(self, input_spec):
state_spec = state_spec.clone()
in_keys_inv = self.in_keys_inv
out_keys_inv = self.out_keys_inv
-for in_key, out_key in zip(in_keys_inv, out_keys_inv):
+for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv):
if in_key != out_key:
# we only change the input spec if the key is the same
continue
@@ -274,7 +275,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
:meth:`TransformedEnv.reset`.
"""
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
value = tensordict.get(in_key, default=None)
if value is not None:
observation = self._apply_transform(value)
@@ -291,7 +292,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Reads the input tensordict, and for the selected keys, applies the transform."""
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
data = tensordict.get(in_key, None)
if data is not None:
data = self._apply_transform(data)
@@ -332,7 +333,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
if not self.in_keys_inv:
return tensordict
-for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
data = tensordict.get(in_key, None)
if data is not None:
item = self._inv_apply_transform(data)
@@ -1637,7 +1638,7 @@ def _reset(self, tensordict: TensorDict, tensordict_reset: TensorDictBase):
return tensordict_reset

def _call(self, tensordict: TensorDict) -> TensorDict:
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
val_in = tensordict.get(in_key, None)
val_out = tensordict.get(out_key, None)
if val_in is not None:
@@ -1679,7 +1680,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key in self.parent.full_observation_spec.keys(True):
target = self.parent.full_observation_spec[in_key]
elif in_key in self.parent.full_reward_spec.keys(True):
@@ -3004,7 +3005,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor:
def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
"""Update the episode tensordict with max pooled keys."""
_just_reset = _reset is not None
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
# Lazy init of buffers
buffer_name = f"_cat_buffers_{in_key}"
data = tensordict.get(in_key)
@@ -3139,12 +3140,12 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
# first sort the in_keys with strings and non-strings
keys = [
(in_key, out_key)
-for in_key, out_key in zip(self.in_keys, self.out_keys)
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys)
if isinstance(in_key, str)
]
keys += [
(in_key, out_key)
-for in_key, out_key in zip(self.in_keys, self.out_keys)
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys)
if not isinstance(in_key, str)
]

@@ -3180,7 +3181,7 @@ def unfold_done(done, N):
first_val = None
if isinstance(in_key, tuple) and in_key[0] == "next":
# let's get the out_key we have already processed
-prev_out_key = dict(zip(self.in_keys, self.out_keys)).get(
+prev_out_key = dict(_zip_strict(self.in_keys, self.out_keys)).get(
in_key[1], None
)
if prev_out_key is not None:
@@ -3613,7 +3614,7 @@ def func(name, item):
return tensordict
else:
# we made sure that if in_keys is not None, out_keys is not None either
-for in_key, out_key in zip(in_keys, out_keys):
+for in_key, out_key in _zip_strict(in_keys, out_keys):
item = self._apply_transform(tensordict.get(in_key))
tensordict.set(out_key, item)
return tensordict
@@ -3672,7 +3673,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
raise NotImplementedError(
f"Calling transform_input_spec without a parent environment isn't supported yet for {type(self)}."
)
-for in_key_inv, out_key_inv in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key_inv, out_key_inv in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key_inv in full_action_spec.keys(True):
_spec = full_action_spec[in_key_inv]
target = "action"
@@ -3706,7 +3707,7 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
full_observation_spec = output_spec["full_observation_spec"]
for reward_key, reward_spec in list(full_reward_spec.items(True, True)):
# find out_key that match the in_key
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if reward_key == in_key:
if reward_spec.dtype != self.dtype_in:
raise TypeError(f"reward_spec.dtype is not {self.dtype_in}")
@@ -3722,7 +3723,7 @@ def transform_observation_spec(self, observation_spec):
full_observation_spec.items(True, True)
):
# find out_key that match the in_key
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if observation_key == in_key:
if observation_spec.dtype != self.dtype_in:
raise TypeError(
@@ -3955,7 +3956,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return result
tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None)
if self._rename_keys:
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
@@ -3969,7 +3970,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
return result
tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None)
if self._rename_keys:
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
@@ -3997,7 +3998,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
device=None,
)
if self._rename_keys_inv:
-for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
@@ -4030,15 +4031,15 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:

def transform_action_spec(self, full_action_spec: Composite) -> Composite:
full_action_spec = full_action_spec.clear_device_()
-for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_action_spec.keys(True, True):
continue
full_action_spec[out_key] = full_action_spec[in_key].to(self.device)
return full_action_spec

def transform_state_spec(self, full_state_spec: Composite) -> Composite:
full_state_spec = full_state_spec.clear_device_()
-for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_state_spec.keys(True, True):
continue
full_state_spec[out_key] = full_state_spec[in_key].to(self.device)
@@ -4052,23 +4053,23 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

def transform_observation_spec(self, observation_spec: Composite) -> Composite:
observation_spec = observation_spec.clear_device_()
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in observation_spec.keys(True, True):
continue
observation_spec[out_key] = observation_spec[in_key].to(self.device)
return observation_spec

def transform_done_spec(self, full_done_spec: Composite) -> Composite:
full_done_spec = full_done_spec.clear_device_()
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_done_spec.keys(True, True):
continue
full_done_spec[out_key] = full_done_spec[in_key].to(self.device)
return full_done_spec

def transform_reward_spec(self, full_reward_spec: Composite) -> Composite:
full_reward_spec = full_reward_spec.clear_device_()
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_reward_spec.keys(True, True):
continue
full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device)
@@ -5023,7 +5024,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.lock is not None:
self.lock.acquire()

-for key, key_out in zip(self.in_keys, self.out_keys):
+for key, key_out in _zip_strict(self.in_keys, self.out_keys):
if key not in tensordict.keys(include_nested=True):
# TODO: init missing rewards with this
# for key_suffix in [_append_last(key, suffix) for suffix in ("_sum", "_ssq", "_count")]:
@@ -5161,7 +5162,7 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
out = []
loc = self.loc
scale = self.scale
-for key, key_out in zip(self.in_keys, self.out_keys):
+for key, key_out in _zip_strict(self.in_keys, self.out_keys):
_out = ObservationNorm(
loc=loc.get(key),
scale=scale.get(key),
@@ -5480,7 +5481,7 @@ def reset_keys(self):

def _check_match(reset_keys, in_keys):
# if this is called, the length of reset_keys and in_keys must match
-for reset_key, in_key in zip(reset_keys, in_keys):
+for reset_key, in_key in _zip_strict(reset_keys, in_keys):
# having _reset at the root and the reward_key ("agent", "reward") is allowed
# but having ("agent", "_reset") and "reward" isn't
if isinstance(reset_key, tuple) and isinstance(in_key, str):
@@ -5524,7 +5525,7 @@ def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
"""Resets episode rewards."""
-for in_key, reset_key, out_key in zip(
+for in_key, reset_key, out_key in _zip_strict(
self.in_keys, self.reset_keys, self.out_keys
):
_reset = _get_reset(reset_key, tensordict)
@@ -5541,7 +5542,7 @@ def _step(
) -> TensorDictBase:
"""Updates the episode rewards with the step rewards."""
# Update episode rewards
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key in next_tensordict.keys(include_nested=True):
reward = next_tensordict.get(in_key)
prev_reward = tensordict.get(out_key, 0.0)
@@ -5563,7 +5564,7 @@ def _generate_episode_reward_spec(self) -> Composite:
reward_spec = self.parent.full_reward_spec
reward_spec_keys = self.parent.reward_keys
# Define episode specs for all out_keys
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if (
in_key in reward_spec_keys
): # if this out_key has a corresponding key in reward_spec
@@ -5613,7 +5614,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"At least one dimension of the tensordict must be named 'time' in offline mode"
)
time_dim = time_dim[0] - 1
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
reward = tensordict[in_key]
cumsum = reward.cumsum(time_dim)
tensordict.set(out_key, cumsum)
@@ -5791,7 +5792,13 @@ def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
# get reset signal
-for step_count_key, truncated_key, terminated_key, reset_key, done_key in zip(
+for (
+step_count_key,
+truncated_key,
+terminated_key,
+reset_key,
+done_key,
+) in _zip_strict(
self.step_count_keys,
self.truncated_keys,
self.terminated_keys,
@@ -5832,10 +5839,8 @@ def _reset(
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
-for step_count_key, truncated_key, done_key in zip(
-self.step_count_keys,
-self.truncated_keys,
-self.done_keys,
+for step_count_key, truncated_key, done_key in _zip_strict(
+self.step_count_keys, self.truncated_keys, self.done_keys
):
step_count = tensordict.get(step_count_key)
next_step_count = step_count + 1
@@ -6334,7 +6339,7 @@ def _make_missing_buffer(self, tensordict, in_key, buffer_name):

def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
"""Update the episode tensordict with max pooled keys."""
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
# Lazy init of buffers
buffer_name = self._buffer_name(in_key)
buffer = getattr(self, buffer_name)
@@ -6575,7 +6580,7 @@ def _reset(
device = tensordict.device
if device is None:
device = torch.device("cpu")
-for reset_key, init_key in zip(self.reset_keys, self.init_keys):
+for reset_key, init_key in _zip_strict(self.reset_keys, self.init_keys):
_reset = tensordict.get(reset_key, None)
if _reset is None:
done_key = _replace_last(init_key, "done")
@@ -6711,15 +6716,15 @@ def __init__(
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.create_copy:
out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance)
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
try:
out.rename_key_(in_key, out_key)
except KeyError:
if not self._missing_tolerance:
raise
tensordict = tensordict.update(out)
else:
-for in_key, out_key in zip(self.in_keys, self.out_keys):
+for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
try:
tensordict.rename_key_(in_key, out_key)
except KeyError:
@@ -6741,7 +6746,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
out = tensordict.select(
*self.out_keys_inv, strict=not self._missing_tolerance
)
-for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
try:
out.rename_key_(out_key, in_key)
except KeyError:
@@ -6750,7 +6755,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:

tensordict = tensordict.update(out)
else:
-for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
try:
tensordict.rename_key_(out_key, in_key)
except KeyError:
@@ -6971,7 +6976,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
"No episode ends found to calculate the reward to go. Make sure that the number of frames_per_batch is larger than number of steps per episode."
)
found = False
-for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
+for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key in tensordict.keys(include_nested=True):
found = True
item = self._inv_apply_transform(tensordict.get(in_key), done)
