Skip to content

Commit

Permalink
Make RecordEpisodeStatistics work with VectorEnv (#2296)
Browse files Browse the repository at this point in the history
* Make RecordEpisodeStatistics work with VectorEnv

* fix test cases

* fix lint

* add test cases

* fix linting

* fix tests

* fix test cases...

* Update gym/wrappers/record_episode_statistics.py

Co-authored-by: Tristan Deleu <[email protected]>

* fix test cases

* fix test cases again

Co-authored-by: Tristan Deleu <[email protected]>
  • Loading branch information
vwxyzjn and tristandeleu authored Aug 5, 2021
1 parent ddd650a commit 1397e70
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 22 deletions.
58 changes: 38 additions & 20 deletions gym/wrappers/record_episode_statistics.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,57 @@
import time
from collections import deque
import numpy as np
import gym


class RecordEpisodeStatistics(gym.Wrapper):
def __init__(self, env, deque_size=100):
super(RecordEpisodeStatistics, self).__init__(env)
self.env_is_vec = isinstance(env, gym.vector.VectorEnv)
self.num_envs = getattr(env, "num_envs", 1)
self.t0 = (
time.time()
) # TODO: use perf_counter when gym removes Python 2 support
self.episode_return = 0.0
self.episode_length = 0
self.episode_count = 0
self.episode_returns = None
self.episode_lengths = None
self.return_queue = deque(maxlen=deque_size)
self.length_queue = deque(maxlen=deque_size)

def reset(self, **kwargs):
observation = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_return = 0.0
self.episode_length = 0
return observation
observations = super(RecordEpisodeStatistics, self).reset(**kwargs)
self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
return observations

def step(self, action):
observation, reward, done, info = super(RecordEpisodeStatistics, self).step(
observations, rewards, dones, infos = super(RecordEpisodeStatistics, self).step(
action
)
self.episode_return += reward
self.episode_length += 1
if done:
info["episode"] = {
"r": self.episode_return,
"l": self.episode_length,
"t": round(time.time() - self.t0, 6),
}
self.return_queue.append(self.episode_return)
self.length_queue.append(self.episode_length)
self.episode_return = 0.0
self.episode_length = 0
return observation, reward, done, info
self.episode_returns += rewards
self.episode_lengths += 1
if not self.env_is_vec:
infos = [infos]
dones = [dones]
for i in range(len(dones)):
if dones[i]:
infos[i] = infos[i].copy()
episode_return = self.episode_returns[i]
episode_length = self.episode_lengths[i]
episode_info = {
"r": episode_return,
"l": episode_length,
"t": round(time.time() - self.t0, 6),
}
infos[i]["episode"] = episode_info
self.return_queue.append(episode_return)
self.length_queue.append(episode_length)
self.episode_count += 1
self.episode_returns[i] = 0
self.episode_lengths[i] = 0
return (
observations,
rewards,
dones if self.env_is_vec else dones[0],
infos if self.env_is_vec else infos[0],
)
18 changes: 16 additions & 2 deletions gym/wrappers/test_record_episode_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_record_episode_statistics(env_id, deque_size):

for n in range(5):
env.reset()
assert env.episode_return == 0.0
assert env.episode_length == 0
assert env.episode_returns[0] == 0.0
assert env.episode_lengths[0] == 0
for t in range(env.spec.max_episode_steps):
_, _, done, info = env.step(env.action_space.sample())
if done:
Expand All @@ -22,3 +22,17 @@ def test_record_episode_statistics(env_id, deque_size):
break
assert len(env.return_queue) == deque_size
assert len(env.length_queue) == deque_size


@pytest.mark.parametrize("num_envs", [1, 4])
def test_record_episode_statistics_with_vectorenv(num_envs):
envs = gym.vector.make("CartPole-v0", num_envs=num_envs, asynchronous=False)
envs = RecordEpisodeStatistics(envs)
envs.reset()
for _ in range(envs.env.envs[0].spec.max_episode_steps + 1):
_, _, dones, infos = envs.step(envs.action_space.sample())
for idx, info in enumerate(infos):
if dones[idx]:
assert "episode" in info
assert all([item in info["episode"] for item in ["r", "l", "t"]])
break

0 comments on commit 1397e70

Please sign in to comment.