[RLlib] New ConnectorV2 API #2: SingleAgentEpisode enhancements. (#41075)
sven1977 authored Nov 30, 2023
1 parent 42c8e0b commit d6d2dee
Showing 19 changed files with 2,613 additions and 1,300 deletions.
37 changes: 27 additions & 10 deletions rllib/BUILD
@@ -683,6 +683,16 @@ py_test(
args = ["--dir=tuned_examples/ppo"]
)

py_test(
name = "test_memory_leak_ppo_new_stack",
tags = ["team:rllib", "memory_leak_tests"],
main = "utils/tests/run_memory_leak_tests.py",
size = "large",
srcs = ["utils/tests/run_memory_leak_tests.py"],
data = ["tuned_examples/ppo/memory-leak-test-ppo-new-stack.py"],
args = ["--dir=tuned_examples/ppo", "--to-check=rollout_worker"]
)

py_test(
name = "test_memory_leak_sac",
tags = ["team:rllib", "memory_leak_tests"],
@@ -772,12 +782,12 @@ py_test(
srcs = ["env/tests/test_multi_agent_env.py"]
)

py_test(
name = "env/tests/test_multi_agent_episode",
tags = ["team:rllib", "env"],
size = "medium",
srcs = ["env/tests/test_multi_agent_episode.py"]
)
# py_test(
# name = "env/tests/test_multi_agent_episode",
# tags = ["team:rllib", "env"],
# size = "medium",
# srcs = ["env/tests/test_multi_agent_episode.py"]
# )

sh_test(
name = "env/tests/test_remote_inference_cartpole",
@@ -818,19 +828,26 @@ sh_test(
# )

py_test(
name = "env/tests/test_single_agent_gym_env_runner",
name = "env/tests/test_single_agent_env_runner",
tags = ["team:rllib", "env"],
size = "medium",
srcs = ["env/tests/test_single_agent_gym_env_runner.py"]
srcs = ["env/tests/test_single_agent_env_runner.py"]
)

py_test(
name = "env/tests/test_single_agent_episode",
tags = ["team:rllib", "env"],
size = "medium",
size = "small",
srcs = ["env/tests/test_single_agent_episode.py"]
)

py_test(
name = "env/tests/test_lookback_buffer",
tags = ["team:rllib", "env"],
size = "small",
srcs = ["env/tests/test_lookback_buffer.py"]
)

py_test(
name = "env/wrappers/tests/test_exception_wrapper",
tags = ["team:rllib", "env"],
@@ -1332,7 +1349,6 @@ py_test(
# Tag: utils
# --------------------------------------------------------------------

# Checkpoint Utils
py_test(
name = "test_checkpoint_utils",
tags = ["team:rllib", "utils"],
@@ -2947,6 +2963,7 @@ py_test(
py_test_module_list(
files = [
"env/wrappers/tests/test_kaggle_wrapper.py",
"env/tests/test_multi_agent_episode.py",
"examples/env/tests/test_cliff_walking_wall_env.py",
"examples/env/tests/test_coin_game_non_vectorized_env.py",
"examples/env/tests/test_coin_game_vectorized_env.py",
79 changes: 41 additions & 38 deletions rllib/algorithms/dreamerv3/utils/env_runner.py
@@ -32,6 +32,9 @@
_, tf, _ = try_import_tf()


# TODO (sven): Use SingleAgentEnvRunner instead of this as soon as we have the new
# ConnectorV2 example classes to make Atari work properly with these (w/o requiring the
# classes at the bottom of this file here, e.g. `ActionClip`).
class DreamerV3EnvRunner(EnvRunner):
"""An environment runner to collect data from vectorized gymnasium environments."""

@@ -144,6 +147,7 @@ def __init__(

self._needs_initial_reset = True
self._episodes = [None for _ in range(self.num_envs)]
self._states = [None for _ in range(self.num_envs)]

# TODO (sven): Move metrics temp storage and collection out of EnvRunner
# and RolloutWorkers. These classes should not continue tracking some data
@@ -254,10 +258,8 @@ def _sample_timesteps(

# Set initial obs and states in the episodes.
for i in range(self.num_envs):
self._episodes[i].add_initial_observation(
initial_observation=obs[i],
initial_state={k: s[i] for k, s in states.items()},
)
self._episodes[i].add_env_reset(observation=obs[i])
self._states[i] = {k: s[i] for k, s in states.items()}
# Don't reset existing envs; continue in already started episodes.
else:
# Pick up stored observations and states from previous timesteps.
@@ -268,7 +270,9 @@
states = {
k: np.stack(
[
initial_states[k][i] if eps.states is None else eps.states[k]
initial_states[k][i]
if self._states[i] is None
else self._states[i][k]
for i, eps in enumerate(self._episodes)
]
)
@@ -278,7 +282,7 @@
# to 1.0, otherwise 0.0.
is_first = np.zeros((self.num_envs,))
for i, eps in enumerate(self._episodes):
if eps.states is None:
if len(eps) == 0:
is_first[i] = 1.0

# Loop through env for n timesteps.
@@ -319,37 +323,39 @@
if terminateds[i] or truncateds[i]:
# Finish the episode with the actual terminal observation stored in
# the info dict.
self._episodes[i].add_timestep(
infos["final_observation"][i],
actions[i],
rewards[i],
state=s,
is_terminated=terminateds[i],
is_truncated=truncateds[i],
self._episodes[i].add_env_step(
observation=infos["final_observation"][i],
action=actions[i],
reward=rewards[i],
terminated=terminateds[i],
truncated=truncateds[i],
)
self._states[i] = s
# Reset h-states to the model's initial ones b/c we are starting a
# new episode.
for k, v in self.module.get_initial_state().items():
states[k][i] = v.numpy()
is_first[i] = True
done_episodes_to_return.append(self._episodes[i])
# Create a new episode object.
self._episodes[i] = SingleAgentEpisode(
observations=[obs[i]], states=s
)
self._episodes[i] = SingleAgentEpisode(observations=[obs[i]])
else:
self._episodes[i].add_timestep(
obs[i], actions[i], rewards[i], state=s
self._episodes[i].add_env_step(
observation=obs[i],
action=actions[i],
reward=rewards[i],
)
is_first[i] = False

self._states[i] = s

# Return done episodes ...
self._done_episodes_for_metrics.extend(done_episodes_to_return)
# ... and all ongoing episode chunks. Also, make sure we return
# a copy and start new chunks so that callers of this function
# don't alter our ongoing and returned Episode objects.
ongoing_episodes = self._episodes
self._episodes = [eps.create_successor() for eps in self._episodes]
self._episodes = [eps.cut() for eps in self._episodes]
for eps in ongoing_episodes:
self._ongoing_episodes_for_metrics[eps.id_].append(eps)

@@ -385,10 +391,9 @@ def _sample_episodes(
render_images = [e.render() for e in self.env.envs]

for i in range(self.num_envs):
episodes[i].add_initial_observation(
initial_observation=obs[i],
initial_state={k: s[i] for k, s in states.items()},
initial_render_image=render_images[i],
episodes[i].add_env_reset(
observation=obs[i],
render_image=render_images[i],
)

eps = 0
@@ -419,19 +424,17 @@
render_images = [e.render() for e in self.env.envs]

for i in range(self.num_envs):
s = {k: s[i] for k, s in states.items()}
# The last entry in self.observations[i] is already the reset
# obs of the new episode.
if terminateds[i] or truncateds[i]:
eps += 1

episodes[i].add_timestep(
infos["final_observation"][i],
actions[i],
rewards[i],
state=s,
is_terminated=terminateds[i],
is_truncated=truncateds[i],
episodes[i].add_env_step(
observation=infos["final_observation"][i],
action=actions[i],
reward=rewards[i],
terminated=terminateds[i],
truncated=truncateds[i],
)
done_episodes_to_return.append(episodes[i])

@@ -448,15 +451,15 @@

episodes[i] = SingleAgentEpisode(
observations=[obs[i]],
states=s,
render_images=[render_images[i]],
render_images=(
[render_images[i]] if with_render_data else None
),
)
else:
episodes[i].add_timestep(
obs[i],
actions[i],
rewards[i],
state=s,
episodes[i].add_env_step(
observation=obs[i],
action=actions[i],
reward=rewards[i],
render_image=render_images[i],
)
is_first[i] = False
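Note on the API change illustrated above (not part of the diff): the env-runner call sites switch from `add_initial_observation()` to `add_env_reset()`, from `add_timestep()` to `add_env_step()`, and from `create_successor()` to `cut()`, and recurrent model states move out of the episode into the runner's own `self._states` list (reset from `module.get_initial_state()` whenever an episode finishes). The sketch below shows the new calling pattern against a plain gymnasium env; it is an assumption-based illustration, not code from this commit. In particular, the import path and the no-argument `SingleAgentEpisode()` constructor are assumed from the call sites visible in the hunks, and the random-action loop stands in for a real module forward pass.

    import gymnasium as gym

    from ray.rllib.env.single_agent_episode import SingleAgentEpisode  # assumed import path

    env = gym.make("CartPole-v1")
    episode = SingleAgentEpisode()  # assumed: episodes start empty and are filled via the methods below

    # Start the episode with the reset observation (previously `add_initial_observation`).
    obs, _ = env.reset()
    episode.add_env_reset(observation=obs)
    assert len(episode) == 0  # a freshly reset episode has no timesteps yet (cf. the `len(eps) == 0` check)

    done_episodes = []
    for _ in range(200):
        action = env.action_space.sample()  # stand-in for the policy/module forward pass
        obs, reward, terminated, truncated, _ = env.step(action)
        # Record one env transition (previously `add_timestep`).
        episode.add_env_step(
            observation=obs,
            action=action,
            reward=reward,
            terminated=terminated,
            truncated=truncated,
        )
        if terminated or truncated:
            # Collect the finished episode and start a new one from the reset obs,
            # mirroring the env-runner logic above.
            done_episodes.append(episode)
            obs, _ = env.reset()
            episode = SingleAgentEpisode(observations=[obs])

    # Hand out the ongoing chunk and continue sampling into a successor chunk that
    # keeps the same episode ID (previously `create_successor`).
    ongoing_chunk = episode
    episode = episode.cut()
    print(len(done_episodes), ongoing_chunk.id_, len(ongoing_chunk), len(episode))

Keeping the model states in the runner rather than on the episode is what the new `self._states` list in the diff is for: the episode only records the env-facing data (observations, actions, rewards, terminated/truncated flags, optional render images), which is why the `states=` arguments disappear from the constructor and step calls above.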