Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RecordEpisodeStatistics work with VectorEnv #2296

Merged
60 changes: 40 additions & 20 deletions gym/wrappers/record_episode_statistics.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,59 @@
import time
from collections import deque
import numpy as np
import gym


class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100):
super(RecordEpisodeStatistics, self).__init__(env)
self.env_is_vec = True
if not isinstance(env, gym.vector.VectorEnv):
self.num_envs = 1
self.env_is_vec = False
vwxyzjn marked this conversation as resolved.
Show resolved Hide resolved
self.t0 = (
time.time()
) # TODO: use perf_counter when gym removes Python 2 support
self.episode_return = 0.0
self.episode_length = 0
self.episode_count = 0
self.episode_returns = None
self.episode_lengths = None
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)

def reset(self, **kwargs):
observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_return = 0.0
self.episode_length = 0
return observation
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.num_envs is not defined here if env is a VectorEnv instance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t follow. VectorEnv has a num_envs attribute, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, the wrapper inherits the properties from env. Sorry, my bad!

self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations

def step(self, action):
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(
observations, rewards, dones, infos = super(RecordEpisodeStatistics, self).step(
action
)
self.episode_return += reward
self.episode_length += 1
if done:
info["episode"] = {
"r": self.episode_return,
"l": self.episode_length,
"t": round(time.time() - self.t0, 6),
}
self.return_queue.append(self.episode_return)
self.length_queue.append(self.episode_length)
self.episode_return = 0.0
self.episode_length = 0
return observation, reward, done, info
self.episode_returns += rewards
self.episode_lengths += 1
if not self.env_is_vec:
infos = [infos]
dones = [dones]
for i in range(len(dones)):
if dones[i]:
infos[i] = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {
"r": episode_return,
"l": episode_length,
"t": round(time.time() - self.t0, 6),
}
infos[i]["episode"] = episode_info
self.return_queue.append(episode_return)
self.length_queue.append(episode_length)
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
return (
observations,
rewards,
dones if self.env_is_vec else dones[0],
infos if self.env_is_vec else infos[0],
)
18 changes: 16 additions & 2 deletions gym/wrappers/test_record_episode_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_record_episode_statistics(env_id, deque_size):

for n in range(5):
env.reset()
assert env.episode_return == 0.0
assert env.episode_length == 0
assert env.episode_returns[0] == 0.0
assert env.episode_lengths[0] == 0
for t in range(env.spec.max_episode_steps):
_, _, done, info = env.step(env.action_space.sample())
if done:
Expand All @@ -22,3 +22,17 @@ def test_record_episode_statistics(env_id, deque_size):
break
assert len(env.return_queue) == deque_size
assert len(env.length_queue) == deque_size


@pytest.mark.parametrize("env_id", ["CartPole-v0"])
def test_record_episode_statistics_with_vectorenv(env_id):
envs = gym.vector.make(env_id, asynchronous=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the corresponding imports

from gym.vector import AsyncVectorEnv, SyncVectorEnv
from gym.vector.tests.utils import make_env
Suggested change
@pytest.mark.parametrize("env_id", ["CartPole-v0"])
def test_record_episode_statistics_with_vectorenv(env_id):
envs = gym.vector.make(env_id, asynchronous=False)
@pytest.mark.parametrize("klass", [SyncVectorEnv, AsyncVectorEnv])
@pytest.mark.parametrize("num_envs", [1, 4])
def test_record_episode_statistics_with_vectorenv(klass, num_envs):
env_fns = [make_env("CartPole-v0", i) for i in range(num_envs)]
envs = klass(env_fns)

Copy link
Contributor Author

@vwxyzjn vwxyzjn Aug 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately it’s gonna fail with AsyncVectorEnv because envs.env.envs[0].spec.max_episode_steps is inaccessible. Maybe I should just hardcode 201 instead of envs.env.envs[0].spec.max_episode_steps? Do we really need the test case with AsyncVectorEnv?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh sorry I didn't see that you were looping over that later. Then you can ignore this (maybe keeping the parametrization for num_envs?).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Would you mind allowing the GitHub Actions workflow runs? I have a weird setup that makes it difficult to run test cases locally.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

envs = RecordEpisodeStatistics(envs)
envs.reset()
for _ in range(envs.env.envs[0].spec.max_episode_steps + 1):
_, _, dones, infos = envs.step(envs.action_space.sample())
for idx, info in enumerate(infos):
if dones[idx]:
assert "episode" in info
assert all([item in info["episode"] for item in ["r", "l", "t"]])
break