Skip to content

Commit

Permalink
add episodic trigger (#2376)
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn authored Aug 31, 2021
1 parent 1504041 commit a004558
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 15 deletions.
38 changes: 33 additions & 5 deletions docs/wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,18 @@ Gym includes numerous wrappers for environments that include preprocessing and v
`RecordEpisodeStatistics(env)` [text]
* Needs review (including for good assertion messages and test coverage)

`RecordVideo(env, video_folder, record_video_trigger, video_length=0, name_prefix="rl-video")` [text]
`RecordVideo(env, video_folder, episode_trigger, step_trigger, video_length=0, name_prefix="rl-video")` [text]

The `RecordVideo` is a lightweight `gym.Wrapper` that helps recording videos. See the following
code as an example.

```python
import gym
from gym.wrappers import RecordVideo, capped_cubic_video_schedule
env = gym.make("CartPole-v1")
env = gym.wrappers.RecordVideo(env, "videos", record_video_trigger=lambda x: x % 100 == 0)
env = RecordVideo(env, "videos")
# the above is equivalent to
# env = RecordVideo(env, "videos", episode_trigger=capped_cubic_video_schedule)
observation = env.reset()
for _ in range(1000):
env.render()
Expand All @@ -89,10 +92,35 @@ for _ in range(1000):
env.close()
```

To use it, you need to specify the `video_folder` as the storing location and
`record_video_trigger` as a frequency at which you want to record.
To use it, you need to specify the `video_folder` as the storing location. By default
the `RecordVideo` uses episode counts to trigger video recording based on the `episode_trigger=capped_cubic_video_schedule`,
which is a cubic progression for early episodes (1,8,27,...) and then every 1000 episodes (1000, 2000, 3000...).
This can be changed by passing a different `episode_trigger` argument to `RecordVideo`.

There are two modes of video the recording:
Alternatively, you may trigger the video recording based on the number of environment steps via the
`step_trigger` argument, for example:

```python
import gym
from gym.wrappers import RecordVideo
env = gym.make("CartPole-v1")
env = RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
observation = env.reset()
for _ in range(1000):
env.render()
action = env.action_space.sample() # your agent here (this takes random actions)
observation, reward, done, info = env.step(action)

if done:
observation = env.reset()
env.close()
```

This will trigger the video recording every 100 environment steps (unless the previous recording hasn't finished yet).

Note that you may use exactly one trigger (i.e. `step_trigger` or `episode_trigger`) at a time.

There are two modes to end the video recording:
1. Episodic mode.
* By default `video_length=0` means the wrapper will record *episodic* videos: it will keep
recording frames until the env returns `done=True`.
Expand Down
2 changes: 1 addition & 1 deletion gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.clip_action import ClipAction
from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
from gym.wrappers.record_video import RecordVideo
from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
40 changes: 36 additions & 4 deletions gym/wrappers/record_video.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,46 @@
import os
import gym
from typing import Callable
import warnings

from gym.wrappers.monitoring import video_recorder


def capped_cubic_video_schedule(episode_id):
    """Decide whether the episode numbered `episode_id` should be recorded.

    Below 1000 the schedule fires on perfect cubes (0, 1, 8, 27, ..., 729);
    from 1000 onwards it fires on every multiple of 1000.

    Args:
        episode_id: index of the episode being considered.

    Returns:
        True when the episode falls on the schedule, False otherwise.
    """
    # Late phase: a fixed period of 1000 episodes.
    if episode_id >= 1000:
        return episode_id % 1000 == 0

    # Early phase: round the floating-point cube root to the nearest integer
    # and check that cubing it reproduces the original id exactly.
    nearest_root = int(round(episode_id ** (1.0 / 3)))
    return nearest_root ** 3 == episode_id


class RecordVideo(gym.Wrapper):
def __init__(
self,
env,
video_folder: str,
record_video_trigger: Callable[[int], bool],
episode_trigger: Callable[[int], bool] = None,
step_trigger: Callable[[int], bool] = None,
video_length: int = 0,
name_prefix: str = "rl-video",
):
super(RecordVideo, self).__init__(env)
self.record_video_trigger = record_video_trigger

if episode_trigger is None and step_trigger is None:
episode_trigger = capped_cubic_video_schedule

trigger_count = sum([x is not None for x in [episode_trigger, step_trigger]])
assert trigger_count == 1, "Must specify exactly one trigger"

self.episode_trigger = episode_trigger
self.step_trigger = step_trigger
self.video_recorder = None

self.video_folder = os.path.abspath(video_folder)
# Create output folder if needed
if os.path.isdir(self.video_folder):
warnings.warn(
f"Overwriting existing videos at {self.video_folder} folder (try specifying a different `video_folder` for the `RecordVideo` wrapper if this is not desired)"
)
os.makedirs(self.video_folder, exist_ok=True)

self.name_prefix = name_prefix
Expand All @@ -29,6 +50,7 @@ def __init__(
self.recording = False
self.recorded_frames = 0
self.is_vector_env = getattr(env, "is_vector_env", False)
self.episode_id = 0

def reset(self, **kwargs):
observations = super(RecordVideo, self).reset(**kwargs)
Expand All @@ -40,17 +62,25 @@ def start_video_recorder(self):
self.close_video_recorder()

video_name = f"{self.name_prefix}-step-{self.step_id}"
if self.episode_trigger:
video_name = f"{self.name_prefix}-episode-{self.episode_id}"

base_path = os.path.join(self.video_folder, video_name)
self.video_recorder = video_recorder.VideoRecorder(
env=self.env, base_path=base_path, metadata={"step_id": self.step_id}
env=self.env,
base_path=base_path,
metadata={"step_id": self.step_id, "episode_id": self.episode_id},
)

self.video_recorder.capture_frame()
self.recorded_frames = 1
self.recording = True

def _video_enabled(self):
    """Return whether a new video recording should start now.

    When a `step_trigger` is configured it is evaluated against the step
    counter (`self.step_id`); otherwise the `episode_trigger` is evaluated
    against the episode counter (`self.episode_id`).

    Returns:
        The result of calling the active trigger.
    """
    if self.step_trigger is not None:
        return self.step_trigger(self.step_id)
    # The episode trigger must see the episode counter; feeding it the step
    # counter would make an episode-based schedule fire on step numbers.
    return self.episode_trigger(self.episode_id)

def step(self, action):
observations, rewards, dones, infos = super(RecordVideo, self).step(action)
Expand All @@ -65,8 +95,10 @@ def step(self, action):
else:
if not self.is_vector_env:
if dones:
self.episode_id += 1
self.close_video_recorder()
elif dones[0]:
self.episode_id += 1
self.close_video_recorder()

elif self._video_enabled():
Expand Down
24 changes: 19 additions & 5 deletions gym/wrappers/test_record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,25 @@
from gym.wrappers import RecordEpisodeStatistics, RecordVideo


def test_record_video():
def test_record_video_using_default_trigger():
env = gym.make("CartPole-v1")
env = gym.wrappers.RecordVideo(
env, "videos", record_video_trigger=lambda x: x % 100 == 0
)
env = gym.wrappers.RecordVideo(env, "videos")
env.reset()
for _ in range(199):
action = env.action_space.sample()
_, _, done, _ = env.step(action)
if done:
env.reset()
env.close()
assert os.path.isdir("videos")
mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
assert len(mp4_files) == env.episode_id
shutil.rmtree("videos")


def test_record_video_step_trigger():
env = gym.make("CartPole-v1")
env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
env.reset()
for _ in range(199):
action = env.action_space.sample()
Expand All @@ -32,7 +46,7 @@ def thunk():
env.observation_space.seed(seed)
if seed == 1:
env = gym.wrappers.RecordVideo(
env, "videos", record_video_trigger=lambda x: x % 100 == 0
env, "videos", step_trigger=lambda x: x % 100 == 0
)
return env

Expand Down

0 comments on commit a004558

Please sign in to comment.