Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make RecordEpisodeStatistics work with VectorEnv #2296

Merged
58 changes: 38 additions & 20 deletions gym/wrappers/record_episode_statistics.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,57 @@
import time
from collections import deque
import numpy as np
import gym


class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100):
super(RecordEpisodeStatistics, self).__init__(env)
self.env_is_vec = isinstance(env, gym.vector.VectorEnv)
self.num_envs = getattr(env, "num_envs", 1)
self.t0 = (
time.time()
) # TODO: use perf_counter when gym removes Python 2 support
self.episode_return = 0.0
self.episode_length = 0
self.episode_count = 0
self.episode_returns = None
self.episode_lengths = None
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)

def reset(self, **kwargs):
observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_return = 0.0
self.episode_length = 0
return observation
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.num_envs is not defined here if env is a VectorEnv instance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t follow. VectorEnv has a num_envs attribute, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, the wrapper inherits the properties from env. Sorry, my bad!

self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations

def step(self, action):
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(
observations, rewards, dones, infos = super(RecordEpisodeStatistics, self).step(
action
)
self.episode_return += reward
self.episode_length += 1
if done:
info["episode"] = {
"r": self.episode_return,
"l": self.episode_length,
"t": round(time.time() - self.t0, 6),
}
self.return_queue.append(self.episode_return)
self.length_queue.append(self.episode_length)
self.episode_return = 0.0
self.episode_length = 0
return observation, reward, done, info
self.episode_returns += rewards
self.episode_lengths += 1
if not self.env_is_vec:
infos = [infos]
dones = [dones]
for i in range(len(dones)):
if dones[i]:
infos[i] = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {
"r": episode_return,
"l": episode_length,
"t": round(time.time() - self.t0, 6),
}
infos[i]["episode"] = episode_info
self.return_queue.append(episode_return)
self.length_queue.append(episode_length)
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
return (
observations,
rewards,
dones if self.env_is_vec else dones[0],
infos if self.env_is_vec else infos[0],
)
18 changes: 16 additions & 2 deletions gym/wrappers/test_record_episode_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_record_episode_statistics(env_id, deque_size):

for n in range(5):
env.reset()
assert env.episode_return == 0.0
assert env.episode_length == 0
assert env.episode_returns[0] == 0.0
assert env.episode_lengths[0] == 0
for t in range(env.spec.max_episode_steps):
_, _, done, info = env.step(env.action_space.sample())
if done:
Expand All @@ -22,3 +22,17 @@ def test_record_episode_statistics(env_id, deque_size):
break
assert len(env.return_queue) == deque_size
assert len(env.length_queue) == deque_size


@pytest.mark.parametrize("num_envs", [1, 4])
def test_record_episode_statistics_with_vectorenv(env_id, num_envs):
    """Check that a wrapped vector env attaches episode stats to finished envs.

    Builds a synchronous "CartPole-v0" vector environment with ``num_envs``
    copies, wraps it in ``RecordEpisodeStatistics``, and steps it with sampled
    actions. Whenever a sub-environment reports done, its info dict must hold
    an "episode" entry containing the keys "r", "l" and "t".

    NOTE(review): ``env_id`` is not used in the body and no parametrization
    for it is visible next to this test -- confirm it is supplied elsewhere
    (fixture / stacked parametrize) or drop it.
    """
    envs = gym.vector.make("CartPole-v0", num_envs=num_envs, asynchronous=False)
    envs = RecordEpisodeStatistics(envs)
    envs.reset()
    # Step one more time than the first sub-environment's max_episode_steps.
    for _ in range(envs.env.envs[0].spec.max_episode_steps + 1):
        _, _, dones, infos = envs.step(envs.action_space.sample())
        for done, info in zip(dones, infos):
            if done:
                assert "episode" in info
                assert all(item in info["episode"] for item in ["r", "l", "t"])
                # NOTE(review): nesting was reconstructed from a flattened
                # listing; this `break` is assumed to belong to the `if`,
                # i.e. only the first finished sub-environment of a step is
                # checked before moving on to the next step.
                break