
Don't override terminal observation when using AutoResetWrapper (#69)
* Add discard_terminal_observation flag to AutoResetWrapper

* Fix comments

* Add AutoResetWrapper test

* Typo

* Discard terminal obs by default, set reset reward

* Fix type error in base_envs

* Revert type error fix

* Pin sphinx-autodoc-typehints version

Co-authored-by: Daniel Filan <[email protected]>
PavelCz and dfilan authored Jan 20, 2023
1 parent dff53ee commit de29873
Showing 3 changed files with 129 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -124,7 +124,7 @@ def get_readme() -> str:
]
DOCS_REQUIRE = [
"sphinx",
"sphinx-autodoc-typehints",
"sphinx-autodoc-typehints>=1.21.5",
"sphinx-rtd-theme",
]

67 changes: 66 additions & 1 deletion src/seals/util.py
@@ -8,9 +8,74 @@


class AutoResetWrapper(gym.Wrapper):
"""Hides done=True and auto-resets at the end of each episode."""
"""Hides done=True and auto-resets at the end of each episode.
Depending on the flag 'discard_terminal_observation', either discards the terminal
observation or pads with an additional 'reset transition'. The former is the default
behavior.
In the latter case, the action taken during the 'reset transition' will not have an
effect, the reward will be constant (set by the wrapper argument `reset_reward`,
which has default value 0.0), and info an empty dictionary.
"""

def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0):
"""Builds the wrapper.
Args:
env: The environment to wrap.
discard_terminal_observation: Defaults to True. If True, the terminal
observation is discarded and the environment is reset immediately. The
returned observation will then be the start of the next episode. The
overridden observation is stored in `info["terminal_observation"]`.
If False, the terminal observation is returned and the environment is
reset in the next step.
reset_reward: The reward to return for the reset transition. Defaults to
0.0.
"""
super().__init__(env)
self.discard_terminal_observation = discard_terminal_observation
self.reset_reward = reset_reward
self.previous_done = False # Whether the previous step returned done=True.

def step(self, action):
"""When done=True, returns done=False, then reset depending on flag.
Depending on whether we are discarding the terminal observation,
either resets the environment and discards,
or returns the terminal observation, and then uses the next step to reset the
environment, after which steps will be performed as normal.
"""
if self.discard_terminal_observation:
return self._step_discard(action)
else:
return self._step_pad(action)

def _step_pad(self, action):
"""When done=True, return done=False instead and return the terminal obs.
The agent will then usually be asked to perform an action based on
the terminal observation. In the next step, this final action will be ignored
to instead reset the environment and return the initial observation of the new
episode.
Some potential caveats:
- The underlying environment will perform fewer steps than the wrapped
environment.
- The number of steps the agent performs and the number of steps recorded in the
underlying environment will not match, which could cause issues if these are
assumed to be the same.
"""
if self.previous_done:
self.previous_done = False
# This transition will only reset the environment, the action is ignored.
return self.env.reset(), self.reset_reward, False, {}

obs, rew, done, info = self.env.step(action)
if done:
self.previous_done = True
return obs, rew, False, info

def _step_discard(self, action):
"""When done=True, returns done=False instead and automatically resets.
When an automatic reset happens, the observation from reset is returned,
Expand Down
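
To illustrate the new padding mode, here is a minimal usage sketch (not part of the commit itself). It assumes the old-style Gym step API `(obs, rew, done, info)` used here, and borrows the `CountingEnv` helper from `seals.testing.envs` that the tests below exercise:

from seals import util
from seals.testing import envs

# Pad with an extra 'reset transition' instead of discarding the terminal obs.
env = util.AutoResetWrapper(
    envs.CountingEnv(episode_length=3),
    discard_terminal_observation=False,
    reset_reward=0.0,
)

obs = env.reset()
for t in range(1, 6):
    obs, rew, done, info = env.step(env.action_space.sample())
    # done is always False: at t == 3 the terminal observation is returned,
    # and t == 4 is the padded 'reset transition' whose reward equals
    # reset_reward and whose obs is the next episode's initial observation.
    print(t, obs, rew, done)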
64 changes: 62 additions & 2 deletions tests/test_wrappers.py
@@ -7,12 +7,72 @@
from seals.testing import envs


def test_auto_reset_wrapper(episode_length=3, n_steps=100, n_manual_reset=2):
def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2):
"""Check that AutoResetWrapper returns correct values from step and reset.
AutoResetWrapper that pads trajectory with an extra transition containing the
terminal observations.
Also check that calls to .reset() do not interfere with automatic resets.
Due to the padding the number of steps counted inside the environment and the number
of steps performed outside the environment, i.e. the number of actions performed,
will differ. This test checks that this difference is consistent.
"""
env = util.AutoResetWrapper(envs.CountingEnv(episode_length=episode_length))
env = util.AutoResetWrapper(
envs.CountingEnv(episode_length=episode_length),
discard_terminal_observation=False,
)

for _ in range(n_manual_reset):
obs = env.reset()
assert obs == 0

# We count the number of episodes, so we can sanity check the padding.
num_episodes = 0
next_episode_end = episode_length
for t in range(1, n_steps + 1):
act = env.action_space.sample()
obs, rew, done, info = env.step(act)

# AutoResetWrapper overrides all done signals.
assert done is False

if t == next_episode_end:
# Unlike the AutoResetWrapper that discards terminal observations,
# here the final observation is returned directly, and is not stored
# in the info dict.
# Due to padding, for every episode the final observation is offset from
# the outer step by one.
assert obs == (t - num_episodes) / (num_episodes + 1)
assert rew == episode_length * 10
if t == next_episode_end + 1:
num_episodes += 1
# Because the final step returned the final observation, the initial
# obs of the next episode is returned in this additional step.
assert obs == 0
# Consequently, the next episode end is one step later, so it is
# episode_length steps from now.
next_episode_end = t + episode_length

# Reward of the 'reset transition' is fixed to be 0.
assert rew == 0

# Sanity check padding. Padding should be 1 for each past episode.
assert (
next_episode_end
== (num_episodes + 1) * episode_length + num_episodes
)
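
As a worked example of the assertion above (illustrative only, not part of the diff): each completed episode contributes one padding step, so with episode_length = 3 the episode ends fall at t = 3, 7, 11, ... rather than 3, 6, 9, ...:

episode_length = 3
for num_episodes in range(3):
    # (num_episodes + 1) * episode_length ordinary steps, plus one padding
    # step for each episode finished so far.
    print((num_episodes + 1) * episode_length + num_episodes)  # -> 3, 7, 11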


def test_auto_reset_wrapper_discard(episode_length=3, n_steps=100, n_manual_reset=2):
"""Check that AutoResetWrapper returns correct values from step and reset.
Tests for AutoResetWrapper that discards terminal observations.
Also check that calls to .reset() do not interfere with automatic resets.
"""
env = util.AutoResetWrapper(
envs.CountingEnv(episode_length=episode_length),
discard_terminal_observation=True,
)

for _ in range(n_manual_reset):
obs = env.reset()
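
For comparison, a minimal sketch of the default discard mode that this second test exercises (same assumptions as the sketch above; not part of the commit). Per the class docstring in src/seals/util.py, the overridden terminal observation ends up in `info["terminal_observation"]`:

from seals import util
from seals.testing import envs

# discard_terminal_observation defaults to True.
env = util.AutoResetWrapper(envs.CountingEnv(episode_length=3))

obs = env.reset()
for t in range(1, 5):
    obs, rew, done, info = env.step(env.action_space.sample())
    assert done is False
    if t == 3:
        # Episode boundary: obs is already the next episode's initial
        # observation, and the discarded terminal obs lives in info.
        assert obs == 0
        assert "terminal_observation" in info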
