Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Ensure transformation keys have the same number of elements #2466

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 47 additions & 42 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from tensordict.nn import dispatch, TensorDictModuleBase
from tensordict.utils import (
_unravel_key_to_tuple,
_zip_strict,
expand_as_right,
expand_right,
NestedKey,
Expand Down Expand Up @@ -88,7 +89,7 @@ def new_fun(self, observation_spec):
_specs = observation_spec._specs
in_keys = self.in_keys
out_keys = self.out_keys
for in_key, out_key in zip(in_keys, out_keys):
for in_key, out_key in _zip_strict(in_keys, out_keys):
if in_key in observation_spec.keys(True, True):
_specs[out_key] = function(self, observation_spec[in_key].clone())
return Composite(
Expand Down Expand Up @@ -118,7 +119,7 @@ def new_fun(self, input_spec):
state_spec = state_spec.clone()
in_keys_inv = self.in_keys_inv
out_keys_inv = self.out_keys_inv
for in_key, out_key in zip(in_keys_inv, out_keys_inv):
for in_key, out_key in _zip_strict(in_keys_inv, out_keys_inv):
if in_key != out_key:
# we only change the input spec if the key is the same
continue
Expand Down Expand Up @@ -274,7 +275,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
:meth:`TransformedEnv.reset`.

"""
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
value = tensordict.get(in_key, default=None)
if value is not None:
observation = self._apply_transform(value)
Expand All @@ -291,7 +292,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Reads the input tensordict, and for the selected keys, applies the transform."""
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
data = tensordict.get(in_key, None)
if data is not None:
data = self._apply_transform(data)
Expand Down Expand Up @@ -332,7 +333,7 @@ def _inv_apply_transform(self, state: torch.Tensor) -> torch.Tensor:
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
if not self.in_keys_inv:
return tensordict
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
data = tensordict.get(in_key, None)
if data is not None:
item = self._inv_apply_transform(data)
Expand Down Expand Up @@ -1637,7 +1638,7 @@ def _reset(self, tensordict: TensorDict, tensordict_reset: TensorDictBase):
return tensordict_reset

def _call(self, tensordict: TensorDict) -> TensorDict:
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
val_in = tensordict.get(in_key, None)
val_out = tensordict.get(out_key, None)
if val_in is not None:
Expand Down Expand Up @@ -1679,7 +1680,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)

def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key in self.parent.full_observation_spec.keys(True):
target = self.parent.full_observation_spec[in_key]
elif in_key in self.parent.full_reward_spec.keys(True):
Expand Down Expand Up @@ -3004,7 +3005,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> torch.Tensor:
def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
"""Update the episode tensordict with max pooled keys."""
_just_reset = _reset is not None
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
# Lazy init of buffers
buffer_name = f"_cat_buffers_{in_key}"
data = tensordict.get(in_key)
Expand Down Expand Up @@ -3139,12 +3140,12 @@ def unfolding(self, tensordict: TensorDictBase) -> TensorDictBase:
# first sort the in_keys with strings and non-strings
keys = [
(in_key, out_key)
for in_key, out_key in zip(self.in_keys, self.out_keys)
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys)
if isinstance(in_key, str)
]
keys += [
(in_key, out_key)
for in_key, out_key in zip(self.in_keys, self.out_keys)
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys)
if not isinstance(in_key, str)
]

Expand Down Expand Up @@ -3180,7 +3181,7 @@ def unfold_done(done, N):
first_val = None
if isinstance(in_key, tuple) and in_key[0] == "next":
# let's get the out_key we have already processed
prev_out_key = dict(zip(self.in_keys, self.out_keys)).get(
prev_out_key = dict(_zip_strict(self.in_keys, self.out_keys)).get(
in_key[1], None
)
if prev_out_key is not None:
Expand Down Expand Up @@ -3613,7 +3614,7 @@ def func(name, item):
return tensordict
else:
# we made sure that if in_keys is not None, out_keys is not None either
for in_key, out_key in zip(in_keys, out_keys):
for in_key, out_key in _zip_strict(in_keys, out_keys):
item = self._apply_transform(tensordict.get(in_key))
tensordict.set(out_key, item)
return tensordict
Expand Down Expand Up @@ -3672,7 +3673,7 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
raise NotImplementedError(
f"Calling transform_input_spec without a parent environment isn't supported yet for {type(self)}."
)
for in_key_inv, out_key_inv in zip(self.in_keys_inv, self.out_keys_inv):
for in_key_inv, out_key_inv in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key_inv in full_action_spec.keys(True):
_spec = full_action_spec[in_key_inv]
target = "action"
Expand Down Expand Up @@ -3706,7 +3707,7 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
full_observation_spec = output_spec["full_observation_spec"]
for reward_key, reward_spec in list(full_reward_spec.items(True, True)):
# find out_key that match the in_key
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if reward_key == in_key:
if reward_spec.dtype != self.dtype_in:
raise TypeError(f"reward_spec.dtype is not {self.dtype_in}")
Expand All @@ -3722,7 +3723,7 @@ def transform_observation_spec(self, observation_spec):
full_observation_spec.items(True, True)
):
# find out_key that match the in_key
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if observation_key == in_key:
if observation_spec.dtype != self.dtype_in:
raise TypeError(
Expand Down Expand Up @@ -3955,7 +3956,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return result
tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None)
if self._rename_keys:
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
Expand All @@ -3969,7 +3970,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
return result
tensordict_t = tensordict.named_apply(self._to, nested_keys=True, device=None)
if self._rename_keys:
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
Expand Down Expand Up @@ -3997,7 +3998,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
device=None,
)
if self._rename_keys_inv:
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if out_key != in_key:
tensordict_t.rename_key_(in_key, out_key)
tensordict_t.set(in_key, tensordict.get(in_key))
Expand Down Expand Up @@ -4030,15 +4031,15 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:

def transform_action_spec(self, full_action_spec: Composite) -> Composite:
full_action_spec = full_action_spec.clear_device_()
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_action_spec.keys(True, True):
continue
full_action_spec[out_key] = full_action_spec[in_key].to(self.device)
return full_action_spec

def transform_state_spec(self, full_state_spec: Composite) -> Composite:
full_state_spec = full_state_spec.clear_device_()
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key not in full_state_spec.keys(True, True):
continue
full_state_spec[out_key] = full_state_spec[in_key].to(self.device)
Expand All @@ -4052,23 +4053,23 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:

def transform_observation_spec(self, observation_spec: Composite) -> Composite:
observation_spec = observation_spec.clear_device_()
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in observation_spec.keys(True, True):
continue
observation_spec[out_key] = observation_spec[in_key].to(self.device)
return observation_spec

def transform_done_spec(self, full_done_spec: Composite) -> Composite:
full_done_spec = full_done_spec.clear_device_()
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_done_spec.keys(True, True):
continue
full_done_spec[out_key] = full_done_spec[in_key].to(self.device)
return full_done_spec

def transform_reward_spec(self, full_reward_spec: Composite) -> Composite:
full_reward_spec = full_reward_spec.clear_device_()
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key not in full_reward_spec.keys(True, True):
continue
full_reward_spec[out_key] = full_reward_spec[in_key].to(self.device)
Expand Down Expand Up @@ -5023,7 +5024,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.lock is not None:
self.lock.acquire()

for key, key_out in zip(self.in_keys, self.out_keys):
for key, key_out in _zip_strict(self.in_keys, self.out_keys):
if key not in tensordict.keys(include_nested=True):
# TODO: init missing rewards with this
# for key_suffix in [_append_last(key, suffix) for suffix in ("_sum", "_ssq", "_count")]:
Expand Down Expand Up @@ -5161,7 +5162,7 @@ def to_observation_norm(self) -> Union[Compose, ObservationNorm]:
out = []
loc = self.loc
scale = self.scale
for key, key_out in zip(self.in_keys, self.out_keys):
for key, key_out in _zip_strict(self.in_keys, self.out_keys):
_out = ObservationNorm(
loc=loc.get(key),
scale=scale.get(key),
Expand Down Expand Up @@ -5480,7 +5481,7 @@ def reset_keys(self):

def _check_match(reset_keys, in_keys):
# if this is called, the length of reset_keys and in_keys must match
for reset_key, in_key in zip(reset_keys, in_keys):
for reset_key, in_key in _zip_strict(reset_keys, in_keys):
# having _reset at the root and the reward_key ("agent", "reward") is allowed
# but having ("agent", "_reset") and "reward" isn't
if isinstance(reset_key, tuple) and isinstance(in_key, str):
Expand Down Expand Up @@ -5524,7 +5525,7 @@ def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
"""Resets episode rewards."""
for in_key, reset_key, out_key in zip(
for in_key, reset_key, out_key in _zip_strict(
self.in_keys, self.reset_keys, self.out_keys
):
_reset = _get_reset(reset_key, tensordict)
Expand All @@ -5541,7 +5542,7 @@ def _step(
) -> TensorDictBase:
"""Updates the episode rewards with the step rewards."""
# Update episode rewards
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if in_key in next_tensordict.keys(include_nested=True):
reward = next_tensordict.get(in_key)
prev_reward = tensordict.get(out_key, 0.0)
Expand All @@ -5563,7 +5564,7 @@ def _generate_episode_reward_spec(self) -> Composite:
reward_spec = self.parent.full_reward_spec
reward_spec_keys = self.parent.reward_keys
# Define episode specs for all out_keys
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
if (
in_key in reward_spec_keys
): # if this out_key has a corresponding key in reward_spec
Expand Down Expand Up @@ -5613,7 +5614,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"At least one dimension of the tensordict must be named 'time' in offline mode"
)
time_dim = time_dim[0] - 1
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
reward = tensordict[in_key]
cumsum = reward.cumsum(time_dim)
tensordict.set(out_key, cumsum)
Expand Down Expand Up @@ -5791,7 +5792,13 @@ def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
# get reset signal
for step_count_key, truncated_key, terminated_key, reset_key, done_key in zip(
for (
step_count_key,
truncated_key,
terminated_key,
reset_key,
done_key,
) in _zip_strict(
self.step_count_keys,
self.truncated_keys,
self.terminated_keys,
Expand Down Expand Up @@ -5832,10 +5839,8 @@ def _reset(
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
for step_count_key, truncated_key, done_key in zip(
self.step_count_keys,
self.truncated_keys,
self.done_keys,
for step_count_key, truncated_key, done_key in _zip_strict(
self.step_count_keys, self.truncated_keys, self.done_keys
):
step_count = tensordict.get(step_count_key)
next_step_count = step_count + 1
Expand Down Expand Up @@ -6334,7 +6339,7 @@ def _make_missing_buffer(self, tensordict, in_key, buffer_name):

def _call(self, tensordict: TensorDictBase, _reset=None) -> TensorDictBase:
"""Update the episode tensordict with max pooled keys."""
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
# Lazy init of buffers
buffer_name = self._buffer_name(in_key)
buffer = getattr(self, buffer_name)
Expand Down Expand Up @@ -6575,7 +6580,7 @@ def _reset(
device = tensordict.device
if device is None:
device = torch.device("cpu")
for reset_key, init_key in zip(self.reset_keys, self.init_keys):
for reset_key, init_key in _zip_strict(self.reset_keys, self.init_keys):
_reset = tensordict.get(reset_key, None)
if _reset is None:
done_key = _replace_last(init_key, "done")
Expand Down Expand Up @@ -6711,15 +6716,15 @@ def __init__(
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.create_copy:
out = tensordict.select(*self.in_keys, strict=not self._missing_tolerance)
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
try:
out.rename_key_(in_key, out_key)
except KeyError:
if not self._missing_tolerance:
raise
tensordict = tensordict.update(out)
else:
for in_key, out_key in zip(self.in_keys, self.out_keys):
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
try:
tensordict.rename_key_(in_key, out_key)
except KeyError:
Expand All @@ -6741,7 +6746,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
out = tensordict.select(
*self.out_keys_inv, strict=not self._missing_tolerance
)
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
try:
out.rename_key_(out_key, in_key)
except KeyError:
Expand All @@ -6750,7 +6755,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:

tensordict = tensordict.update(out)
else:
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
try:
tensordict.rename_key_(out_key, in_key)
except KeyError:
Expand Down Expand Up @@ -6971,7 +6976,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
"No episode ends found to calculate the reward to go. Make sure that the number of frames_per_batch is larger than number of steps per episode."
)
found = False
for in_key, out_key in zip(self.in_keys_inv, self.out_keys_inv):
for in_key, out_key in _zip_strict(self.in_keys_inv, self.out_keys_inv):
if in_key in tensordict.keys(include_nested=True):
found = True
item = self._inv_apply_transform(tensordict.get(in_key), done)
Expand Down
Loading