From edd8b72123c6de7b11cb81d8bf5d1996e5dc9f14 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 4 Aug 2021 23:32:04 -0400 Subject: [PATCH 01/10] Make RecordEpisodeStatistics work with VectorEnv --- gym/wrappers/record_episode_statistics.py | 57 +++++++++++++---------- 1 file changed, 32 insertions(+), 25 deletions(-) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index 9ce1e7f086d..350a7be1964 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -1,39 +1,46 @@ import time from collections import deque +import numpy as np import gym class RecordEpisodeStatistics(gym.Wrapper): def __init__(self, env, deque_size=100): super(RecordEpisodeStatistics, self).__init__(env) - self.t0 = ( - time.time() - ) # TODO: use perf_counter when gym removes Python 2 support - self.episode_return = 0.0 - self.episode_length = 0 + self.env_is_vec = True + if not isinstance(env, gym.vector.VectorEnv): + self.num_envs = 1 + self.env_is_vec = False + self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support + self.episode_count = 0 + self.episode_returns = None + self.episode_lengths = None self.return_queue = deque(maxlen=deque_size) self.length_queue = deque(maxlen=deque_size) def reset(self, **kwargs): - observation = super(RecordEpisodeStatistics, self).reset(**kwargs) - self.episode_return = 0.0 - self.episode_length = 0 - return observation + observations = super(RecordEpisodeStatistics, self).reset(**kwargs) + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return observations def step(self, action): - observation, reward, done, info = super(RecordEpisodeStatistics, self).step( - action - ) - self.episode_return += reward - self.episode_length += 1 - if done: - info["episode"] = { - "r": self.episode_return, - "l": self.episode_length, - "t": round(time.time() - self.t0, 6), - } - self.return_queue.append(self.episode_return) - self.length_queue.append(self.episode_length) - self.episode_return = 0.0 - self.episode_length = 0 - return observation, reward, done, info + observations, rewards, dones, infos = super(RecordEpisodeStatistics, self).step(action) + self.episode_returns += rewards + self.episode_lengths += 1 + if not self.env_is_vec: + infos = [infos] + dones = [dones] + for i in range(len(dones)): + if dones[i]: + infos[i] = infos[i].copy() + episode_return = self.episode_returns[i] + episode_length = self.episode_lengths[i] + episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t0, 6)} + infos[i]["episode"] = episode_info + self.return_queue.append(episode_return) + self.length_queue.append(episode_length) + self.episode_count += 1 + self.episode_returns[i] = 0 + self.episode_lengths[i] = 0 + return observations, rewards, dones if self.env_is_vec else dones[0], infos if self.env_is_vec else infos[0] From 14198222f39a1c8dd7a18e75c1d8abab4bf40600 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 00:13:20 -0400 Subject: [PATCH 02/10] fix test cases --- gym/wrappers/test_record_episode_statistics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gym/wrappers/test_record_episode_statistics.py b/gym/wrappers/test_record_episode_statistics.py index 3aaf7bd1093..817cc0de488 100644 --- a/gym/wrappers/test_record_episode_statistics.py +++ b/gym/wrappers/test_record_episode_statistics.py @@ -12,8 +12,8 @@ def test_record_episode_statistics(env_id, deque_size): for n in range(5): env.reset() - assert env.episode_return == 0.0 - assert env.episode_length == 0 + assert env.episode_returns[0] == 0.0 + assert env.episode_lengths[0] == 0 for t in range(env.spec.max_episode_steps): _, _, done, info = env.step(env.action_space.sample()) if done: From aed31bf89de228a31cfe5c22464c3254716063ee Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 00:14:39 -0400 Subject: [PATCH 03/10] fix lint --- gym/wrappers/record_episode_statistics.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index 350a7be1964..ffa9dd3a28f 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -11,7 +11,9 @@ def __init__(self, env, deque_size=100): if not isinstance(env, gym.vector.VectorEnv): self.num_envs = 1 self.env_is_vec = False - self.t0 = time.time() # TODO: use perf_counter when gym removes Python 2 support + self.t0 = ( + time.time() + ) # TODO: use perf_counter when gym removes Python 2 support self.episode_count = 0 self.episode_returns = None self.episode_lengths = None @@ -25,7 +27,9 @@ def reset(self, **kwargs): return observations def step(self, action): - observations, rewards, dones, infos = super(RecordEpisodeStatistics, self).step(action) + observations, rewards, dones, infos = super(RecordEpisodeStatistics, self).step( + action + ) self.episode_returns += rewards self.episode_lengths += 1 if not self.env_is_vec: @@ -36,11 +40,20 @@ def step(self, action): infos[i] = infos[i].copy() episode_return = self.episode_returns[i] episode_length = self.episode_lengths[i] - episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t0, 6)} + episode_info = { + "r": episode_return, + "l": episode_length, + "t": round(time.time() - self.t0, 6), + } infos[i]["episode"] = episode_info self.return_queue.append(episode_return) self.length_queue.append(episode_length) self.episode_count += 1 self.episode_returns[i] = 0 self.episode_lengths[i] = 0 - return observations, rewards, dones if self.env_is_vec else dones[0], infos if self.env_is_vec else infos[0] + return ( + observations, + rewards, + dones if self.env_is_vec else dones[0], + infos if self.env_is_vec else infos[0], + ) From 0f25b42c35ccd1e4dd1bcd0738674e30029f966b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 00:21:03 -0400 Subject: [PATCH 04/10] add test cases --- gym/wrappers/test_record_episode_statistics.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/gym/wrappers/test_record_episode_statistics.py b/gym/wrappers/test_record_episode_statistics.py index 817cc0de488..e8389127dc1 100644 --- a/gym/wrappers/test_record_episode_statistics.py +++ b/gym/wrappers/test_record_episode_statistics.py @@ -22,3 +22,15 @@ def test_record_episode_statistics(env_id, deque_size): break assert len(env.return_queue) == deque_size assert len(env.length_queue) == deque_size + +@pytest.mark.parametrize("env_id", ["CartPole-v0"]) +def test_record_episode_statistics_with_vectorenv(env_id): + envs = gym.vector.make(env_id) + envs = RecordEpisodeStatistics(envs) + envs.reset() + for _ in range(envs.spec.max_episode_steps+1): + _, _, _, infos = envs.step(envs.action_space.sample()) + for info in infos: + assert "episode" in info + assert all([item in info["episode"] for item in ["r", "l", "t"]]) + break From 225e13083d65e7706c536d6983f2ff4d9262c5ef Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 00:21:37 -0400 Subject: [PATCH 05/10] fix linting --- gym/wrappers/test_record_episode_statistics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gym/wrappers/test_record_episode_statistics.py b/gym/wrappers/test_record_episode_statistics.py index e8389127dc1..4bc2518d442 100644 --- a/gym/wrappers/test_record_episode_statistics.py +++ b/gym/wrappers/test_record_episode_statistics.py @@ -23,12 +23,13 @@ def test_record_episode_statistics(env_id, deque_size): assert len(env.return_queue) == deque_size assert len(env.length_queue) == deque_size + @pytest.mark.parametrize("env_id", ["CartPole-v0"]) def test_record_episode_statistics_with_vectorenv(env_id): envs = gym.vector.make(env_id) envs = RecordEpisodeStatistics(envs) envs.reset() - for _ in range(envs.spec.max_episode_steps+1): + for _ in range(envs.spec.max_episode_steps + 1): _, _, _, infos = envs.step(envs.action_space.sample()) for info in infos: assert "episode" in info From d48114b5fb50b967f585373c7bc6127c4a243d6d Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 00:56:09 -0400 Subject: [PATCH 06/10] fix tests --- gym/wrappers/test_record_episode_statistics.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/gym/wrappers/test_record_episode_statistics.py b/gym/wrappers/test_record_episode_statistics.py index 4bc2518d442..409d9c4c7c9 100644 --- a/gym/wrappers/test_record_episode_statistics.py +++ b/gym/wrappers/test_record_episode_statistics.py @@ -29,9 +29,10 @@ def test_record_episode_statistics_with_vectorenv(env_id): envs = gym.vector.make(env_id) envs = RecordEpisodeStatistics(envs) envs.reset() - for _ in range(envs.spec.max_episode_steps + 1): - _, _, _, infos = envs.step(envs.action_space.sample()) - for info in infos: - assert "episode" in info - assert all([item in info["episode"] for item in ["r", "l", "t"]]) - break + for _ in range(envs.env.envs[0].spec.max_episode_steps + 1): + _, _, dones, infos = envs.step(envs.action_space.sample()) + for idx, info in enumerate(infos): + if dones[idx]: + assert "episode" in info + assert all([item in info["episode"] for item in ["r", "l", "t"]]) + break From a3e60ddbbef3691e38b922ff1041b9e590a323b8 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 09:30:28 -0400 Subject: [PATCH 07/10] fix test cases... --- gym/wrappers/test_record_episode_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gym/wrappers/test_record_episode_statistics.py b/gym/wrappers/test_record_episode_statistics.py index 409d9c4c7c9..2db7b63eff1 100644 --- a/gym/wrappers/test_record_episode_statistics.py +++ b/gym/wrappers/test_record_episode_statistics.py @@ -26,7 +26,7 @@ def test_record_episode_statistics(env_id, deque_size): @pytest.mark.parametrize("env_id", ["CartPole-v0"]) def test_record_episode_statistics_with_vectorenv(env_id): - envs = gym.vector.make(env_id) + envs = gym.vector.make(env_id, asynchronous=False) envs = RecordEpisodeStatistics(envs) envs.reset() for _ in range(envs.env.envs[0].spec.max_episode_steps + 1): From e6df45278a1cd03778a3fcc64a50e943bf21c333 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 14:04:21 -0400 Subject: [PATCH 08/10] Update gym/wrappers/record_episode_statistics.py Co-authored-by: Tristan Deleu --- gym/wrappers/record_episode_statistics.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/gym/wrappers/record_episode_statistics.py b/gym/wrappers/record_episode_statistics.py index ffa9dd3a28f..32e26d2a3f5 100644 --- a/gym/wrappers/record_episode_statistics.py +++ b/gym/wrappers/record_episode_statistics.py @@ -7,10 +7,8 @@ class RecordEpisodeStatistics(gym.Wrapper): def __init__(self, env, deque_size=100): super(RecordEpisodeStatistics, self).__init__(env) - self.env_is_vec = True - if not isinstance(env, gym.vector.VectorEnv): - self.num_envs = 1 - self.env_is_vec = False + self.env_is_vec = isinstance(env, gym.vector.VectorEnv) + self.num_envs = getattr(env, "num_envs", 1) self.t0 = ( time.time() ) # TODO: use perf_counter when gym removes Python 2 support From c97d946fb45e3cfd010d5b0802212c8919ed9f88 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 14:26:21 -0400 Subject: [PATCH 09/10] fix test cases --- gym/wrappers/test_record_episode_statistics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gym/wrappers/test_record_episode_statistics.py b/gym/wrappers/test_record_episode_statistics.py index 2db7b63eff1..5e4922a9a6c 100644 --- a/gym/wrappers/test_record_episode_statistics.py +++ b/gym/wrappers/test_record_episode_statistics.py @@ -24,9 +24,9 @@ def test_record_episode_statistics(env_id, deque_size): assert len(env.length_queue) == deque_size -@pytest.mark.parametrize("env_id", ["CartPole-v0"]) -def test_record_episode_statistics_with_vectorenv(env_id): - envs = gym.vector.make(env_id, asynchronous=False) +@pytest.mark.parametrize("num_envs", [1, 4]) +def test_record_episode_statistics_with_vectorenv(env_id, num_envs): + envs = gym.vector.make("CartPole-v0", num_envs=num_envs, asynchronous=False) envs = RecordEpisodeStatistics(envs) envs.reset() for _ in range(envs.env.envs[0].spec.max_episode_steps + 1): From 72ac78d13d7cac6df54e354841229a6fb8f6499b Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Thu, 5 Aug 2021 16:29:56 -0400 Subject: [PATCH 10/10] fix test cases again --- gym/wrappers/test_record_episode_statistics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gym/wrappers/test_record_episode_statistics.py b/gym/wrappers/test_record_episode_statistics.py index 5e4922a9a6c..ce68eab4991 100644 --- a/gym/wrappers/test_record_episode_statistics.py +++ b/gym/wrappers/test_record_episode_statistics.py @@ -25,7 +25,7 @@ def test_record_episode_statistics(env_id, deque_size): @pytest.mark.parametrize("num_envs", [1, 4]) -def test_record_episode_statistics_with_vectorenv(env_id, num_envs): +def test_record_episode_statistics_with_vectorenv(num_envs): envs = gym.vector.make("CartPole-v0", num_envs=num_envs, asynchronous=False) envs = RecordEpisodeStatistics(envs) envs.reset()