import time
from collections import deque

import numpy as np
import pytest

import gym


class RecordEpisodeStatistics(gym.Wrapper):
    """Track per-episode return and length for a single env or a vector env.

    When a (sub-)episode finishes, the corresponding ``info`` dict gains an
    ``"episode"`` entry::

        {"r": <episode return>, "l": <episode length>, "t": <seconds since wrapper creation>}

    and the return/length are appended to the bounded ``return_queue`` /
    ``length_queue`` deques (most recent ``deque_size`` episodes).

    Works transparently for both a plain ``gym.Env`` and a
    ``gym.vector.VectorEnv``; in the vector case statistics are tracked
    independently per sub-environment.
    """

    def __init__(self, env, deque_size=100):
        super(RecordEpisodeStatistics, self).__init__(env)
        # Vectorized envs return batched rewards/dones and a sequence of
        # per-env info dicts; remember which protocol to speak in step().
        self.env_is_vec = isinstance(env, gym.vector.VectorEnv)
        self.num_envs = getattr(env, "num_envs", 1)
        self.t0 = (
            time.time()
        )  # TODO: use perf_counter when gym removes Python 2 support
        self.episode_count = 0
        # Allocated on reset(); one slot per sub-environment.
        self.episode_returns = None
        self.episode_lengths = None
        self.return_queue = deque(maxlen=deque_size)
        self.length_queue = deque(maxlen=deque_size)

    def reset(self, **kwargs):
        """Reset the env and zero the per-env return/length accumulators."""
        observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
        return observations

    def step(self, action):
        """Step the env, annotating each finished episode's info dict."""
        observations, rewards, dones, infos = super(
            RecordEpisodeStatistics, self
        ).step(action)
        self.episode_returns += rewards
        self.episode_lengths += 1
        if not self.env_is_vec:
            # Lift the scalar case into the batched shape so one code path
            # handles both; unwrapped again in the return below.
            infos = [infos]
            dones = [dones]
        else:
            # Vector envs may hand back infos as a tuple; make it mutable so
            # per-env dicts can be replaced with annotated copies below.
            infos = list(infos)
        for i in range(len(dones)):
            if dones[i]:
                # Copy before annotating — never mutate a dict the wrapped
                # env may reuse across steps.
                infos[i] = infos[i].copy()
                episode_return = self.episode_returns[i]
                episode_length = self.episode_lengths[i]
                infos[i]["episode"] = {
                    "r": episode_return,
                    "l": episode_length,
                    "t": round(time.time() - self.t0, 6),
                }
                self.return_queue.append(episode_return)
                self.length_queue.append(episode_length)
                self.episode_count += 1
                self.episode_returns[i] = 0
                self.episode_lengths[i] = 0
        return (
            observations,
            rewards,
            dones if self.env_is_vec else dones[0],
            infos if self.env_is_vec else infos[0],
        )


# ---------------------------------------------------------------------------
# Tests (from gym/wrappers/test_record_episode_statistics.py; only the test
# fully visible in the patch is reproduced here).
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("num_envs", [1, 4])
def test_record_episode_statistics_with_vectorenv(num_envs):
    """Every finished sub-episode must carry an 'episode' info entry."""
    envs = gym.vector.make("CartPole-v0", num_envs=num_envs, asynchronous=False)
    envs = RecordEpisodeStatistics(envs)
    envs.reset()
    # CartPole cannot outlast max_episode_steps, so at least one done occurs.
    for _ in range(envs.env.envs[0].spec.max_episode_steps + 1):
        _, _, dones, infos = envs.step(envs.action_space.sample())
        for idx, info in enumerate(infos):
            if dones[idx]:
                assert "episode" in info
                assert all([item in info["episode"] for item in ["r", "l", "t"]])
                break