Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't override terminal observation when using AutoResetWrapper #69

Merged
merged 9 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def get_readme() -> str:
]
DOCS_REQUIRE = [
"sphinx",
"sphinx-autodoc-typehints",
"sphinx-autodoc-typehints>=1.21.5",
"sphinx-rtd-theme",
]

Expand Down
67 changes: 66 additions & 1 deletion src/seals/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,74 @@


class AutoResetWrapper(gym.Wrapper):
"""Hides done=True and auto-resets at the end of each episode."""
"""Hides done=True and auto-resets at the end of each episode.

Depending on the flag 'discard_terminal_observation', either discards the terminal
observation or pads with an additional 'reset transition'. The former is the default
behavior.
In the latter case, the action taken during the 'reset transition' will not have an
effect, the reward will be constant (set by the wrapper argument `reset_reward`,
which has default value 0.0), and info an empty dictionary.
"""

def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0):
"""Builds the wrapper.

Args:
env: The environment to wrap.
discard_terminal_observation: Defaults to True. If True, the terminal
observation is discarded and the environment is reset immediately. The
returned observation will then be the start of the next episode. The
overridden observation is stored in `info["terminal_observation"]`.
If False, the terminal observation is returned and the environment is
reset in the next step.
reset_reward: The reward to return for the reset transition. Defaults to
0.0.
"""
super().__init__(env)
self.discard_terminal_observation = discard_terminal_observation
self.reset_reward = reset_reward
self.previous_done = False # Whether the previous step returned done=True.

def step(self, action):
"""When done=True, returns done=False, then reset depending on flag.

Depending on whether we are discarding the terminal observation,
either resets the environment and discards,
or returns the terminal observation, and then uses the next step to reset the
environment, after which steps will be performed as normal.
"""
if self.discard_terminal_observation:
return self._step_discard(action)
else:
return self._step_pad(action)

def _step_pad(self, action):
"""When done=True, return done=False instead and return the terminal obs.

The agent will then usually be asked to perform an action based on
the terminal observation. In the next step, this final action will be ignored
to instead reset the environment and return the initial observation of the new
episode.

Some potential caveats:
- The underlying environment will perform fewer steps than the wrapped
environment.
- The number of steps the agent performs and the number of steps recorded in the
underlying environment will not match, which could cause issues if these are
assumed to be the same.
"""
if self.previous_done:
self.previous_done = False
# This transition will only reset the environment, the action is ignored.
return self.env.reset(), self.reset_reward, False, {}

obs, rew, done, info = self.env.step(action)
if done:
self.previous_done = True
return obs, rew, False, info

def _step_discard(self, action):
"""When done=True, returns done=False instead and automatically resets.

When an automatic reset happens, the observation from reset is returned,
Expand Down
64 changes: 62 additions & 2 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,72 @@
from seals.testing import envs


def test_auto_reset_wrapper(episode_length=3, n_steps=100, n_manual_reset=2):
def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2):
"""Check that AutoResetWrapper returns correct values from step and reset.

AutoResetWrapper that pads trajectory with an extra transition containing the
terminal observations.
Also check that calls to .reset() do not interfere with automatic resets.
Due to the padding the number of steps counted inside the environment and the number
of steps performed outside the environment, i.e. the number of actions performed,
will differ. This test checks that this difference is consistent.
"""
env = util.AutoResetWrapper(envs.CountingEnv(episode_length=episode_length))
env = util.AutoResetWrapper(
envs.CountingEnv(episode_length=episode_length),
discard_terminal_observation=False,
)

for _ in range(n_manual_reset):
obs = env.reset()
assert obs == 0

# We count the number of episodes, so we can sanity check the padding.
num_episodes = 0
next_episode_end = episode_length
for t in range(1, n_steps + 1):
act = env.action_space.sample()
obs, rew, done, info = env.step(act)

# AutoResetWrapper overrides all done signals.
assert done is False

if t == next_episode_end:
# Unlike the AutoResetWrapper that discards terminal observations,
# here the final observation is returned directly, and is not stored
# in the info dict.
# Due to padding, for every episode the final observation is offset from
# the outer step by one.
assert obs == (t - num_episodes) / (num_episodes + 1)
assert rew == episode_length * 10
if t == next_episode_end + 1:
num_episodes += 1
# Because the final step returned the final observation, the initial
# obs of the next episode is returned in this additional step.
assert obs == 0
# Consequently, the next episode end is one step later, so it is
# episode_length steps from now.
next_episode_end = t + episode_length

# Reward of the 'reset transition' is fixed to be 0.
assert rew == 0

# Sanity check padding. Padding should be 1 for each past episode.
assert (
next_episode_end
== (num_episodes + 1) * episode_length + num_episodes
)


def test_auto_reset_wrapper_discard(episode_length=3, n_steps=100, n_manual_reset=2):
"""Check that AutoResetWrapper returns correct values from step and reset.

Tests for AutoResetWrapper that discards terminal observations.
Also check that calls to .reset() do not interfere with automatic resets.
"""
env = util.AutoResetWrapper(
envs.CountingEnv(episode_length=episode_length),
discard_terminal_observation=True,
)

for _ in range(n_manual_reset):
obs = env.reset()
Expand Down