From d6d2deef5278812a8f93e4a7d982d51d7b53b51a Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 30 Nov 2023 13:14:41 +0100 Subject: [PATCH] [RLlib] New ConnectorV2 API #02: SingleAgentEpisode enhancements. (#41075) --- rllib/BUILD | 37 +- .../algorithms/dreamerv3/utils/env_runner.py | 79 +- rllib/env/multi_agent_episode.py | 443 +++--- rllib/env/single_agent_env_runner.py | 178 ++- rllib/env/single_agent_episode.py | 1291 +++++++++++++---- rllib/env/testing/__init__.py | 0 .../testing/single_agent_gym_env_runner.py | 243 ---- rllib/env/tests/test_lookback_buffer.py | 491 +++++++ rllib/env/tests/test_multi_agent_episode.py | 242 +-- ...ner.py => test_single_agent_env_runner.py} | 50 +- rllib/env/tests/test_single_agent_episode.py | 418 +++--- rllib/env/utils.py | 335 ++++- rllib/evaluation/postprocessing_v2.py | 6 +- .../ppo/memory-leak-test-ppo-new-stack.py | 17 + rllib/utils/numpy.py | 24 +- .../tests/test_episode_replay_buffer.py | 4 +- rllib/utils/spaces/space_utils.py | 36 +- rllib/utils/spaces/tests/test_space_utils.py | 6 + rllib/utils/tests/run_memory_leak_tests.py | 13 +- 19 files changed, 2613 insertions(+), 1300 deletions(-) delete mode 100644 rllib/env/testing/__init__.py delete mode 100644 rllib/env/testing/single_agent_gym_env_runner.py create mode 100644 rllib/env/tests/test_lookback_buffer.py rename rllib/env/tests/{test_single_agent_gym_env_runner.py => test_single_agent_env_runner.py} (54%) create mode 100644 rllib/tuned_examples/ppo/memory-leak-test-ppo-new-stack.py diff --git a/rllib/BUILD b/rllib/BUILD index 8d8fc3f73244..977b56be70ac 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -683,6 +683,16 @@ py_test( args = ["--dir=tuned_examples/ppo"] ) +py_test( + name = "test_memory_leak_ppo_new_stack", + tags = ["team:rllib", "memory_leak_tests"], + main = "utils/tests/run_memory_leak_tests.py", + size = "large", + srcs = ["utils/tests/run_memory_leak_tests.py"], + data = ["tuned_examples/ppo/memory-leak-test-ppo-new-stack.py"], + args = ["--dir=tuned_examples/ppo", "--to-check=rollout_worker"] +) + py_test( name = "test_memory_leak_sac", tags = ["team:rllib", "memory_leak_tests"], @@ -772,12 +782,12 @@ py_test( srcs = ["env/tests/test_multi_agent_env.py"] ) -py_test( - name = "env/tests/test_multi_agent_episode", - tags = ["team:rllib", "env"], - size = "medium", - srcs = ["env/tests/test_multi_agent_episode.py"] -) +# py_test( +# name = "env/tests/test_multi_agent_episode", +# tags = ["team:rllib", "env"], +# size = "medium", +# srcs = ["env/tests/test_multi_agent_episode.py"] +# ) sh_test( name = "env/tests/test_remote_inference_cartpole", @@ -818,19 +828,26 @@ sh_test( # ) py_test( - name = "env/tests/test_single_agent_gym_env_runner", + name = "env/tests/test_single_agent_env_runner", tags = ["team:rllib", "env"], size = "medium", - srcs = ["env/tests/test_single_agent_gym_env_runner.py"] + srcs = ["env/tests/test_single_agent_env_runner.py"] ) py_test( name = "env/tests/test_single_agent_episode", tags = ["team:rllib", "env"], - size = "medium", + size = "small", srcs = ["env/tests/test_single_agent_episode.py"] ) +py_test( + name = "env/tests/test_lookback_buffer", + tags = ["team:rllib", "env"], + size = "small", + srcs = ["env/tests/test_lookback_buffer.py"] +) + py_test( name = "env/wrappers/tests/test_exception_wrapper", tags = ["team:rllib", "env"], @@ -1332,7 +1349,6 @@ py_test( # Tag: utils # -------------------------------------------------------------------- -# Checkpoint Utils py_test( name = "test_checkpoint_utils", tags = ["team:rllib", "utils"], @@ -2947,6 +2963,7 @@ py_test( py_test_module_list( files = [ "env/wrappers/tests/test_kaggle_wrapper.py", + "env/tests/test_multi_agent_episode.py", "examples/env/tests/test_cliff_walking_wall_env.py", "examples/env/tests/test_coin_game_non_vectorized_env.py", "examples/env/tests/test_coin_game_vectorized_env.py", diff --git a/rllib/algorithms/dreamerv3/utils/env_runner.py b/rllib/algorithms/dreamerv3/utils/env_runner.py index 259a27e4f7df..0fcf9a746235 100644 --- a/rllib/algorithms/dreamerv3/utils/env_runner.py +++ b/rllib/algorithms/dreamerv3/utils/env_runner.py @@ -32,6 +32,9 @@ _, tf, _ = try_import_tf() +# TODO (sven): Use SingleAgentEnvRunner instead of this as soon as we have the new +# ConnectorV2 example classes to make Atari work properly with these (w/o requiring the +# classes at the bottom of this file here, e.g. `ActionClip`). class DreamerV3EnvRunner(EnvRunner): """An environment runner to collect data from vectorized gymnasium environments.""" @@ -144,6 +147,7 @@ def __init__( self._needs_initial_reset = True self._episodes = [None for _ in range(self.num_envs)] + self._states = [None for _ in range(self.num_envs)] # TODO (sven): Move metrics temp storage and collection out of EnvRunner # and RolloutWorkers. These classes should not continue tracking some data @@ -254,10 +258,8 @@ def _sample_timesteps( # Set initial obs and states in the episodes. for i in range(self.num_envs): - self._episodes[i].add_initial_observation( - initial_observation=obs[i], - initial_state={k: s[i] for k, s in states.items()}, - ) + self._episodes[i].add_env_reset(observation=obs[i]) + self._states[i] = {k: s[i] for k, s in states.items()} # Don't reset existing envs; continue in already started episodes. else: # Pick up stored observations and states from previous timesteps. @@ -268,7 +270,9 @@ def _sample_timesteps( states = { k: np.stack( [ - initial_states[k][i] if eps.states is None else eps.states[k] + initial_states[k][i] + if self._states[i] is None + else self._states[i][k] for i, eps in enumerate(self._episodes) ] ) @@ -278,7 +282,7 @@ def _sample_timesteps( # to 1.0, otherwise 0.0. is_first = np.zeros((self.num_envs,)) for i, eps in enumerate(self._episodes): - if eps.states is None: + if len(eps) == 0: is_first[i] = 1.0 # Loop through env for n timesteps. @@ -319,14 +323,14 @@ def _sample_timesteps( if terminateds[i] or truncateds[i]: # Finish the episode with the actual terminal observation stored in # the info dict. - self._episodes[i].add_timestep( - infos["final_observation"][i], - actions[i], - rewards[i], - state=s, - is_terminated=terminateds[i], - is_truncated=truncateds[i], + self._episodes[i].add_env_step( + observation=infos["final_observation"][i], + action=actions[i], + reward=rewards[i], + terminated=terminateds[i], + truncated=truncateds[i], ) + self._states[i] = s # Reset h-states to the model's initial ones b/c we are starting a # new episode. for k, v in self.module.get_initial_state().items(): @@ -334,22 +338,24 @@ def _sample_timesteps( is_first[i] = True done_episodes_to_return.append(self._episodes[i]) # Create a new episode object. - self._episodes[i] = SingleAgentEpisode( - observations=[obs[i]], states=s - ) + self._episodes[i] = SingleAgentEpisode(observations=[obs[i]]) else: - self._episodes[i].add_timestep( - obs[i], actions[i], rewards[i], state=s + self._episodes[i].add_env_step( + observation=obs[i], + action=actions[i], + reward=rewards[i], ) is_first[i] = False + self._states[i] = s + # Return done episodes ... self._done_episodes_for_metrics.extend(done_episodes_to_return) # ... and all ongoing episode chunks. Also, make sure, we return # a copy and start new chunks so that callers of this function # don't alter our ongoing and returned Episode objects. ongoing_episodes = self._episodes - self._episodes = [eps.create_successor() for eps in self._episodes] + self._episodes = [eps.cut() for eps in self._episodes] for eps in ongoing_episodes: self._ongoing_episodes_for_metrics[eps.id_].append(eps) @@ -385,10 +391,9 @@ def _sample_episodes( render_images = [e.render() for e in self.env.envs] for i in range(self.num_envs): - episodes[i].add_initial_observation( - initial_observation=obs[i], - initial_state={k: s[i] for k, s in states.items()}, - initial_render_image=render_images[i], + episodes[i].add_env_reset( + observation=obs[i], + render_image=render_images[i], ) eps = 0 @@ -419,19 +424,17 @@ def _sample_episodes( render_images = [e.render() for e in self.env.envs] for i in range(self.num_envs): - s = {k: s[i] for k, s in states.items()} # The last entry in self.observations[i] is already the reset # obs of the new episode. if terminateds[i] or truncateds[i]: eps += 1 - episodes[i].add_timestep( - infos["final_observation"][i], - actions[i], - rewards[i], - state=s, - is_terminated=terminateds[i], - is_truncated=truncateds[i], + episodes[i].add_env_step( + observation=infos["final_observation"][i], + action=actions[i], + reward=rewards[i], + terminated=terminateds[i], + truncated=truncateds[i], ) done_episodes_to_return.append(episodes[i]) @@ -448,15 +451,15 @@ def _sample_episodes( episodes[i] = SingleAgentEpisode( observations=[obs[i]], - states=s, - render_images=[render_images[i]], + render_images=( + [render_images[i]] if with_render_data else None + ), ) else: - episodes[i].add_timestep( - obs[i], - actions[i], - rewards[i], - state=s, + episodes[i].add_env_step( + observation=obs[i], + action=actions[i], + reward=rewards[i], render_image=render_images[i], ) is_first[i] = False diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index 873a9db1c26a..50add0b4aaad 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -1,8 +1,9 @@ -import numpy as np -import uuid - +from collections import defaultdict from queue import Queue from typing import Any, Dict, List, Optional, Set, Union +import uuid + +import numpy as np from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.policy.sample_batch import MultiAgentBatch @@ -29,14 +30,14 @@ def __init__( agent_episode_ids: Optional[Dict[str, str]] = None, *, observations: Optional[List[MultiAgentDict]] = None, + infos: Optional[List[MultiAgentDict]] = None, actions: Optional[List[MultiAgentDict]] = None, rewards: Optional[List[MultiAgentDict]] = None, - infos: Optional[List[MultiAgentDict]] = None, - t_started: Optional[int] = None, - is_terminated: Union[List[MultiAgentDict], bool] = False, - is_truncated: Union[List[MultiAgentDict], bool] = False, + terminateds: Union[MultiAgentDict, bool] = False, + truncateds: Union[MultiAgentDict, bool] = False, render_images: Optional[List[np.ndarray]] = None, extra_model_outputs: Optional[List[MultiAgentDict]] = None, + t_started: Optional[int] = None, ) -> "MultiAgentEpisode": """Initializes a `MultiAgentEpisode`. @@ -45,43 +46,49 @@ def __init__( If None, a hexadecimal id is created. In case of providing a string, make sure that it is unique, as episodes get concatenated via this string. - agent_ids: Obligatory. A list of strings containing the agent ids. + agent_ids: A list of strings containing the agent ids. These have to be provided at initialization. - agent_episode_ids: Optional. Either a dictionary mapping agent ids + agent_episode_ids: Either a dictionary mapping agent ids corresponding `SingleAgentEpisode` or None. If None, each `SingleAgentEpisode` in `MultiAgentEpisode.agent_episodes` will generate a hexadecimal code. If a dictionary is provided make sure that ids are unique as agents' `SingleAgentEpisode`s get concatenated or recreated by it. - observations: A dictionary mapping from agent ids to observations. - Can be None. If provided, it should be provided together with - all other episode data (actions, rewards, etc.) - actions: A dictionary mapping from agent ids to corresponding - actions. Can be None. If provided, it should be provided - together with all other episode data (observations, rewards, - etc.). - rewards: A dictionary mapping from agent ids to corresponding rewards. - Can be None. If provided, it should be provided together with - all other episode data (observations, rewards, etc.). - infos: A dictionary mapping from agent ids to corresponding infos. - Can be None. If provided, it should be provided together with - all other episode data (observations, rewards, etc.). - t_started: Optional. An unsigned int that defines the starting point - of the episode. This is only different from zero, if an ongoing - episode is created. - is_terminazted: Optional. A boolean defining, if an environment has - terminated. The default is `False`, i.e. the episode is ongoing. - is_truncated: Optional. A boolean, defining, if an environment is - truncated. The default is `False`, i.e. the episode is ongoing. - render_images: Optional. A list of RGB uint8 images from rendering - the environment. - extra_model_outputs: Optional. A dictionary mapping agent ids to their - corresponding extra model outputs. Each of the latter is a list of - dictionaries containing specific model outputs for the algorithm - used (e.g. `vf_preds` and `action_logp` for PPO) from a rollout. - If data is provided it should be complete (i.e. observations, - actions, rewards, is_terminated, is_truncated, and all necessary - `extra_model_outputs`). + observations: A list of dictionaries mapping agent IDs to observations. + Can be None. If provided, should match all other episode data + (actions, rewards, etc.) in terms of list lengths and agent IDs. + infos: A list of dictionaries mapping agent IDs to info dicts. + Can be None. If provided, should match all other episode data + (observations, rewards, etc.) in terms of list lengths and agent IDs. + actions: A list of dictionaries mapping agent IDs to actions. + Can be None. If provided, should match all other episode data + (observations, rewards, etc.) in terms of list lengths and agent IDs. + rewards: A list of dictionaries mapping agent IDs to rewards. + Can be None. If provided, should match all other episode data + (actions, rewards, etc.) in terms of list lengths and agent IDs. + terminateds: A boolean defining if an environment has + terminated OR a MultiAgentDict mapping individual agent ids + to boolean flags indicating whether individual agents have terminated. + A special __all__ key in these dicts indicates, whether the episode + is terminated for all agents. + The default is `False`, i.e. the episode has not been terminated. + truncateds: A boolean defining if the environment has been + truncated OR a MultiAgentDict mapping individual agent ids + to boolean flags indicating whether individual agents have been + truncated. A special __all__ key in these dicts indicates, whether the + episode is truncated for all agents. + The default is `False`, i.e. the episode has not been truncated. + render_images: A list of RGB uint8 images from rendering + the multi-agent environment. + extra_model_outputs: A list of dictionaries mapping agent IDs to their + corresponding extra model outputs. Each of these "outputs" is a dict + mapping keys (str) to model output values, for example for + `key=STATE_OUT`, the values would be the internal state outputs for + that agent. + t_started: The timestep (int) that defines the starting point + of the episode. This is only larger zero, if an ongoing episode is + created, for example by slicing an ongoing episode or by calling + the `cut()` method on an ongoing episode. """ self.id_: str = id_ or uuid.uuid4().hex @@ -147,16 +154,16 @@ def __init__( # If this is an ongoing episode than the last `__all__` should be `False` self.is_terminated: bool = ( - is_terminated - if isinstance(is_terminated, bool) - else is_terminated[-1]["__all__"] + terminateds + if isinstance(terminateds, bool) + else terminateds.get("__all__", False) ) # If this is an ongoing episode than the last `__all__` should be `False` self.is_truncated: bool = ( - is_truncated - if isinstance(is_truncated, bool) - else is_truncated[-1]["__all__"] + truncateds + if isinstance(truncateds, bool) + else truncateds.get("__all__", False) ) # Note that all attributes will be recorded along the global timestep @@ -169,8 +176,8 @@ def __init__( actions, rewards, infos, - is_terminated, - is_truncated, + terminateds, + truncateds, extra_model_outputs, ) for agent_id in self._agent_ids @@ -773,30 +780,27 @@ def get_truncateds(self) -> MultiAgentDict: truncateds.update({"__all__": self.is_terminated}) return truncateds - def add_initial_observation( + def add_env_reset( self, *, - initial_observation: MultiAgentDict, - initial_info: Optional[MultiAgentDict] = None, - initial_render_image: Optional[np.ndarray] = None, + observations: MultiAgentDict, + infos: Optional[MultiAgentDict] = None, + render_image: Optional[np.ndarray] = None, ) -> None: """Stores initial observation. Args: - initial_observation: A dictionary mapping agent ids - to initial observations. Note that some agents may not have an initial - observation. - initial_info: A dictionary mapping agent ids to initial - infos. Note that some agents may not have an initial - info dict. - initial_render_image: An RGB uint8 image from rendering the - environment. + observations: A dictionary mapping agent ids to initial observations. + Note that some agents may not have an initial observation. + infos: A dictionary mapping agent ids to initial info dicts. + Note that some agents may not have an initial info dict. + render_image: An RGB uint8 image from rendering the environment. """ assert not self.is_done # Assume that this episode is completely empty and has not stepped yet. # Leave self.t (and self.t_started) at 0. assert self.t == self.t_started == 0 - initial_info = initial_info or {} + infos = infos or {} # TODO (simon): After clearing with sven for initialization of timesteps # this might be removed. @@ -807,66 +811,65 @@ def add_initial_observation( # Note that we store the render images into the `MultiAgentEpisode` # instead into each `SingleAgentEpisode`. - if initial_render_image is not None: - self.render_images.append(initial_render_image) + if render_image is not None: + self.render_images.append(render_image) # Note, all agents will have an initial observation. - for agent_id in initial_observation.keys(): + for agent_id in observations.keys(): # Add initial timestep for each agent to the timestep mapping. self.global_t_to_local_t[agent_id].append(self.t) # Add initial observations to the agent's episode. - self.agent_episodes[agent_id].add_initial_observation( + self.agent_episodes[agent_id].add_env_reset( # Note, initial observation has to be provided. - initial_observation=initial_observation[agent_id], - initial_info=initial_info.get(agent_id), - initial_state=None, + observation=observations[agent_id], + infos=infos.get(agent_id), ) - def add_timestep( + def add_env_step( self, - observation: MultiAgentDict, - action: MultiAgentDict, - reward: MultiAgentDict, + observations: MultiAgentDict, + actions: MultiAgentDict, + rewards: MultiAgentDict, *, - info: Optional[MultiAgentDict] = None, - is_terminated: Optional[MultiAgentDict] = None, - is_truncated: Optional[MultiAgentDict] = None, + infos: Optional[MultiAgentDict] = None, + terminateds: Optional[MultiAgentDict] = None, + truncateds: Optional[MultiAgentDict] = None, render_image: Optional[np.ndarray] = None, - extra_model_output: Optional[MultiAgentDict] = None, + extra_model_outputs: Optional[MultiAgentDict] = None, ) -> None: """Adds a timestep to the episode. Args: - observation: A dictionary mapping agent ids to their corresponding + observations: A dictionary mapping agent ids to their corresponding observations. Note that some agents may not have stepped at this timestep. - action: Mandatory. A dictionary mapping agent ids to their + actions: Mandatory. A dictionary mapping agent ids to their corresponding actions. Note that some agents may not have stepped at this timestep. - reward: Mandatory. A dictionary mapping agent ids to their + rewards: Mandatory. A dictionary mapping agent ids to their corresponding observations. Note that some agents may not have stepped at this timestep. - info: A dictionary mapping agent ids to their + infos: A dictionary mapping agent ids to their corresponding info. Note that some agents may not have stepped at this timestep. - is_terminated: A dictionary mapping agent ids to their `terminated` flags, + terminateds: A dictionary mapping agent ids to their `terminated` flags, indicating, whether the environment has been terminated for them. A special `__all__` key indicates that the episode is terminated for all agent ids. - is_terminated: A dictionary mapping agent ids to their `truncated` flags, + terminateds: A dictionary mapping agent ids to their `truncated` flags, indicating, whether the environment has been truncated for them. A special `__all__` key indicates that the episode is `truncated` for all agent ids. render_image: An RGB uint8 image from rendering the environment. - extra_model_output: Optional. A dictionary mapping agent ids to their + extra_model_outputs: Optional. A dictionary mapping agent ids to their corresponding specific model outputs (also in a dictionary; e.g. `vf_preds` for PPO). """ # Cannot add data to an already done episode. assert not self.is_done - is_terminated = is_terminated or {} - is_truncated = is_truncated or {} + terminateds = terminateds or {} + truncateds = truncateds or {} # Environment step. self.t += 1 @@ -875,8 +878,8 @@ def add_timestep( # terminated or truncated? # TODO (simon): Maybe allow user to not provide this and then `__all__` is # False? - self.is_terminated = is_terminated.get("__all__", False) - self.is_truncated = is_truncated.get("__all__", False) + self.is_terminated = terminateds.get("__all__", False) + self.is_truncated = truncateds.get("__all__", False) # Note that we store the render images into the `MultiAgentEpisode` # instead of storing them into each `SingleAgentEpisode`. @@ -889,15 +892,13 @@ def add_timestep( if self.agent_episodes[agent_id].is_done: continue - agent_is_terminated = ( - is_terminated.get(agent_id, False) or self.is_terminated - ) - agent_is_truncated = is_truncated.get(agent_id, False) or self.is_truncated + agent_is_terminated = terminateds.get(agent_id, False) or self.is_terminated + agent_is_truncated = truncateds.get(agent_id, False) or self.is_truncated # CASE 1: observation, no action. # If we have an observation, but no action, we might have a buffered action, # or an initial agent observation. - if agent_id in observation and agent_id not in action: + if agent_id in observations and agent_id not in actions: # We have a buffered action. if self.agent_buffers[agent_id]["actions"].full(): # Get the action from the buffer. @@ -913,10 +914,10 @@ def add_timestep( # default of zero reward. agent_reward = self.agent_buffers[agent_id]["rewards"].get_nowait() # We might also got some reward in this episode. - if agent_id in reward: - agent_reward += reward[agent_id] + if agent_id in rewards: + agent_reward += rewards[agent_id] # Also add to the global reward list. - self.partial_rewards[agent_id].append(reward[agent_id]) + self.partial_rewards[agent_id].append(rewards[agent_id]) # And add to the global reward timestep mapping. self.partial_rewards_t[agent_id].append(self.t) @@ -935,14 +936,14 @@ def add_timestep( # mapping. self.global_t_to_local_t[agent_id].append(self.t) # Add data to `SingleAgentEpisode. - self.agent_episodes[agent_id].add_timestep( - observation=observation[agent_id], + self.agent_episodes[agent_id].add_env_step( + observation=observations[agent_id], action=agent_action, reward=agent_reward, - info=info.get(agent_id), - is_terminated=agent_is_terminated, - is_truncated=agent_is_truncated, - extra_model_output=agent_extra_model_output, + infos=infos.get(agent_id), + terminated=agent_is_terminated, + truncated=agent_is_truncated, + extra_model_outputs=agent_extra_model_output, ) # We have no buffered action. else: @@ -969,24 +970,24 @@ def add_timestep( self.global_t_to_local_t[agent_id].append(self.t) # The agent might have got a reward. # TODO (simon): Refactor to a function `record_rewards`. - if agent_id in reward: + if agent_id in rewards: # Add the reward to the one in the buffer. self.agent_buffers[agent_id]["rewards"].put_nowait( self.agent_buffers[agent_id]["rewards"].get_nowait() - + reward[agent_id] + + rewards[agent_id] ) # Add the reward to the partial rewards of this agent. - self.partial_rewards[agent_id].append(reward[agent_id]) + self.partial_rewards[agent_id].append(rewards[agent_id]) self.partial_rewards_t[agent_id].append(self.t) - self.agent_episodes[agent_id].add_initial_observation( - initial_observation=observation[agent_id], - initial_info=info.get(agent_id), + self.agent_episodes[agent_id].add_env_reset( + observation=observations[agent_id], + infos=infos.get(agent_id), ) # CASE 2: No observation, but action. # We have no observation, but we have an action. This must be an orphane # action and we need to buffer it. - elif agent_id not in observation and agent_id in action: + elif agent_id not in observations and agent_id in actions: # Maybe the agent got terminated. if agent_is_terminated or agent_is_truncated: # If this was indeed the agent's last step, we need to record it @@ -999,15 +1000,16 @@ def add_timestep( # agent_id list? # If the agent was terminated and no observation is provided, # take the last one. - self.agent_episodes[agent_id].add_timestep( + self.agent_episodes[agent_id].add_env_step( observation=self.agent_episodes[agent_id].observations[-1], - action=action[agent_id], - reward=0.0 if agent_id not in reward else reward[agent_id], - is_terminated=agent_is_terminated, - is_truncated=agent_is_truncated, - extra_model_output=None - if agent_id not in extra_model_output - else extra_model_output[agent_id], + action=actions[agent_id], + reward=0.0 if agent_id not in rewards else rewards[agent_id], + infos=self.agent_episodes[agent_id].infos[-1], + terminated=agent_is_terminated, + truncated=agent_is_truncated, + extra_model_outputs=None + if agent_id not in extra_model_outputs + else extra_model_outputs[agent_id], ) # Agent is still alive. else: @@ -1015,36 +1017,38 @@ def add_timestep( # original action timestep (global one). Right now the # `global_reward_t` might serve here. # Buffer the action. - self.agent_buffers[agent_id]["actions"].put_nowait(action[agent_id]) + self.agent_buffers[agent_id]["actions"].put_nowait( + actions[agent_id] + ) # Record the timestep for the action. self.global_actions_t[agent_id].append(self.t) # If available, buffer also reward. Note, if the agent is terminated # or truncated, we finish the `SingleAgentEpisode`. - if agent_id in reward: + if agent_id in rewards: # Add the reward to the existing one in the buffer. Note, the # default value is zero. # TODO (simon): Refactor to `record_rewards()`. self.agent_buffers[agent_id]["rewards"].put_nowait( self.agent_buffers[agent_id]["rewards"].get_nowait() - + reward[agent_id] + + rewards[agent_id] ) # Add to the global reward list. - self.partial_rewards[agent_id].append(reward[agent_id]) + self.partial_rewards[agent_id].append(rewards[agent_id]) # Add also to the global reward timestep mapping. self.partial_rewards_t[agent_id].append(self.t) # If the agent got any extra model outputs, buffer them, too. - if extra_model_output and agent_id in extra_model_output: + if extra_model_outputs and agent_id in extra_model_outputs: # Flush the default `None` from buffer. self.agent_buffers[agent_id]["extra_model_outputs"].get_nowait() # STore the extra model outputs into the buffer. self.agent_buffers[agent_id]["extra_model_outputs"].put_nowait( - extra_model_output[agent_id] + extra_model_outputs[agent_id] ) # CASE 3: No observation and no action. # We have neither observation nor action. Then, we could have `reward`, # `is_terminated` or `is_truncated` and should record it. - elif agent_id not in observation and agent_id not in action: + elif agent_id not in observations and agent_id not in actions: # The agent could be is_terminated if agent_is_terminated or agent_is_truncated: # If the agent has never stepped, we treat it as not being @@ -1076,11 +1080,11 @@ def add_timestep( # as it is initialized as a zero reward. agent_reward = self.agent_buffers[agent_id]["rewards"].get_nowait() # If a reward is received at this timestep record it. - if agent_id in reward: + if agent_id in rewards: # TODO (simon): Refactor to `record_rewards()`. - agent_reward += reward[agent_id] + agent_reward += rewards[agent_id] # Add to the global reward list. - self.partial_rewards[agent_id].append(reward[agent_id]) + self.partial_rewards[agent_id].append(rewards[agent_id]) # Add also to the global reward timestep mapping. self.partial_rewards_t[agent_id].append(self.t) @@ -1088,27 +1092,27 @@ def add_timestep( # it in the timestep mapping. self.global_t_to_local_t[agent_id].append(self.t) # Finish the agent's episode. - self.agent_episodes[agent_id].add_timestep( + self.agent_episodes[agent_id].add_env_step( observation=self.agent_episodes[agent_id].observations[-1], action=agent_action, reward=agent_reward, - info=info.get(agent_id), - is_terminated=agent_is_terminated, - is_truncated=agent_is_truncated, - extra_model_output=agent_extra_model_output, + infos=infos.get(agent_id), + terminated=agent_is_terminated, + truncated=agent_is_truncated, + extra_model_outputs=agent_extra_model_output, ) # The agent is still alive. else: # If the agent received an reward (triggered by actions of # other agents) we collect it and add it to the one in the # buffer. - if agent_id in reward: + if agent_id in rewards: self.agent_buffers[agent_id]["rewards"].put_nowait( self.agent_buffers[agent_id]["rewards"].get_nowait() - + reward[agent_id] + + rewards[agent_id] ) # Add to the global reward list. - self.partial_rewards[agent_id].append(reward[agent_id]) + self.partial_rewards[agent_id].append(rewards[agent_id]) # Add also to the global reward timestep mapping. self.partial_rewards_t[agent_id].append(self.t) # CASE 4: Observation and action. @@ -1128,31 +1132,31 @@ def add_timestep( self.global_t_to_local_t[agent_id].append(self.t) # Record the action to the global action timestep mapping. self.global_actions_t[agent_id].append(self.t) - if agent_id in reward: + if agent_id in rewards: # Also add to the global reward list. - self.partial_rewards[agent_id].append(reward[agent_id]) + self.partial_rewards[agent_id].append(rewards[agent_id]) # And add to the global reward timestep mapping. self.partial_rewards_t[agent_id].append(self.t) # Add timestep to `SingleAgentEpisode`. - self.agent_episodes[agent_id].add_timestep( - observation=observation[agent_id], - action=action[agent_id], - reward=0.0 if agent_id not in reward else reward[agent_id], - info=info.get(agent_id), - is_terminated=agent_is_terminated, - is_truncated=agent_is_truncated, - extra_model_output=None - if extra_model_output is None - else extra_model_output[agent_id], + self.agent_episodes[agent_id].add_env_step( + observation=observations[agent_id], + action=actions[agent_id], + reward=0.0 if agent_id not in rewards else rewards[agent_id], + infos=infos.get(agent_id), + terminated=agent_is_terminated, + truncated=agent_is_truncated, + extra_model_outputs=None + if extra_model_outputs is None + else extra_model_outputs[agent_id], ) @property def is_done(self): """Whether the episode is actually done (terminated or truncated). - A done episode cannot be continued via `self.add_timestep()` or being + A done episode cannot be continued via `self.add_env_step()` or being concatenated on its right-side with another episode chunk or being - succeeded via `self.create_successor()`. + succeeded via `self.cut()`. Note that in a multi-agent environment this does not necessarily correspond to single agents having terminated or being truncated. @@ -1178,23 +1182,15 @@ def is_done(self): # TODO (sven, simon): We are taking over dead agents to the successor # is this intended or should we better check during concatenation, if # the union of agents from both episodes is included? Next issue. - def create_successor(self) -> "MultiAgentEpisode": - """Restarts an ongoing episode from its last observation. - - Note, this method is used so far specifically for the case of - `batch_mode="truncated_episodes"` to ensure that episodes are - immutable inside the `EnvRunner` when truncated and passed over - to postprocessing. - - The newly created `MultiAgentEpisode` contains the same id, and - starts at the timestep where it's predecessor stopped in the last - rollout. Last observations, infos, rewards, etc. are carried over - from the predecessor. This also helps to not carry stale data that - had been collected in the last rollout when rolling out the new - policy in the next iteration (rollout). - - Returns: A MultiAgentEpisode starting at the timepoint where - its predecessor stopped. + def cut(self) -> "MultiAgentEpisode": + """Returns a successor episode chunk (of len=0) continuing from this Episode. + + The successor will have the same ID as `self` and starts at the timestep where + it's predecessor `self` left off. The current observations and infos + are carried over from the predecessor as initial observations/infos. + + Returns: A MultiAgentEpisode starting at the timestep where its predecessor + stopped. """ assert not self.is_done @@ -1205,8 +1201,8 @@ def create_successor(self) -> "MultiAgentEpisode": agent_id: agent_eps.id_ for agent_id, agent_eps in self.agent_episodes.items() }, - is_terminated=self.is_terminated, - is_truncated=self.is_truncated, + terminateds=self.is_terminated, + truncateds=self.is_truncated, t_started=self.t, ) @@ -1215,7 +1211,7 @@ def create_successor(self) -> "MultiAgentEpisode": # all agents that are still alive. if not agent_eps.is_done and agent_eps.observations: # Build a successor for each agent that is not done, yet. - successor.agent_episodes[agent_id] = agent_eps.create_successor() + successor.agent_episodes[agent_id] = agent_eps.cut() # Record the initial observation in the global timestep mapping. successor.global_t_to_local_t[agent_id] = _IndexMapping( [self.global_t_to_local_t[agent_id][-1]] @@ -1225,8 +1221,8 @@ def create_successor(self) -> "MultiAgentEpisode": else: successor.agent_episodes[agent_id] = SingleAgentEpisode( id_=agent_eps.id_, - is_terminated=agent_eps.is_terminated, - is_truncated=agent_eps.is_truncated, + terminated=agent_eps.is_terminated, + truncated=agent_eps.is_truncated, ) successor.global_t_to_local_t[agent_id] = _IndexMapping() @@ -1324,7 +1320,7 @@ def to_sample_batch(self) -> MultiAgentBatch: # Note, only agents that have stepped are included into the batch. return MultiAgentBatch( policy_batches={ - agent_id: agent_eps.to_sample_batch() + agent_id: agent_eps.get_sample_batch() for agent_id, agent_eps in self.agent_episodes.items() if agent_eps.t - agent_eps.t_started > 0 }, @@ -1347,7 +1343,7 @@ def get_return(self, consider_buffer=False) -> float: sum(len(agent_map) for agent_map in self.global_t_to_local_t.values()) > 0 ), ( "ERROR: Cannot determine return of episode that hasn't started, yet!" - "Call `MultiAgentEpisode.add_initial_observation(initial_observation=)` " + "Call `MultiAgentEpisode.add_env_reset(observations=)` " "first (after which `get_return(MultiAgentEpisode)` will be 0)." ) env_return = sum( @@ -1489,47 +1485,14 @@ def _generate_single_agent_episode( actions: Optional[List[MultiAgentDict]] = None, rewards: Optional[List[MultiAgentDict]] = None, infos: Optional[List[MultiAgentDict]] = None, - is_terminateds: Union[MultiAgentDict, bool] = False, - is_truncateds: Union[MultiAgentDict, bool] = False, - extra_model_outputs: Optional[MultiAgentDict] = None, + terminateds: Union[MultiAgentDict, bool] = False, + truncateds: Union[MultiAgentDict, bool] = False, + extra_model_outputs: Optional[List[MultiAgentDict]] = None, ) -> SingleAgentEpisode: - """Generates a `SingleAgentEpisode` from multi-agent data. + """Generates a SingleAgentEpisode from multi-agent data. Note, if no data is provided an empty `SingleAgentEpiosde` - will be returned that starts at `SIngleAgentEpisode.t_started=0`. - - Args: - agent_id: String, idnetifying the agent for which the data should - be extracted. - agent_episode_ids: Optional. A dictionary mapping agents to - corresponding episode ids. If `None` the `SingleAgentEpisode` - creates a hexadecimal code. - observations: Optional. A list of dictionaries, each mapping - from agent ids to observations. When data is provided - it should be complete, i.e. observations, actions, rewards, - etc. should be provided. - actions: Optional. A list of dictionaries, each mapping - from agent ids to actions. When data is provided - it should be complete, i.e. observations, actions, rewards, - etc. should be provided. - rewards: Optional. A list of dictionaries, each mapping - from agent ids to rewards. When data is provided - it should be complete, i.e. observations, actions, rewards, - etc. should be provided. - infos: Optional. A list of dictionaries, each mapping - from agent ids to infos. When data is provided - it should be complete, i.e. observations, actions, rewards, - etc. should be provided. - extra_model_outputs: Optional. A list of agent mappings for every - timestep. Each of these dictionaries maps an agent to its - corresponding `extra_model_outputs`, which a re specific model - outputs needed by the algorithm used (e.g. `vf_preds` and - `action_logp` for PPO). f data is provided it should be complete - (i.e. observations, actions, rewards, is_terminated, is_truncated, - and all necessary `extra_model_outputs`). - - Returns: An instance of `SingleAgentEpisode` containing the agent's - extracted episode data. + will be returned that starts at `SingleAgentEpisode.t_started=0`. """ # If an episode id for an agent episode was provided assign it. @@ -1569,7 +1532,7 @@ def _generate_single_agent_episode( ) # Like observations, infos start at timestep `t=0`, so we do not need to # shift or start later when using the global timestep mapping. But we - # need to use tha timestep carriage in case the starting timestep is + # need to use the timestep carriage in case the starting timestep is # different from the length of observations-after-initialization. agent_infos = ( None @@ -1579,7 +1542,7 @@ def _generate_single_agent_episode( ) ) - agent_extra_model_outputs = ( + _agent_extra_model_outputs = ( None if extra_model_outputs is None else self._get_single_agent_data( @@ -1588,35 +1551,15 @@ def _generate_single_agent_episode( use_global_t_to_local_t=False, ) ) + # Convert `extra_model_outputs` for this agent from list of dicts to dict + # of lists. + agent_extra_model_outputs = defaultdict(list) + for _model_out in _agent_extra_model_outputs: + for key, val in _model_out.items(): + agent_extra_model_outputs[key].append(val) - agent_is_terminated = ( - [False] - if is_terminateds is None - else self._get_single_agent_data( - agent_id, is_terminateds, use_global_t_to_local_t=False - ) - # else self._get_single_agent_data( - # agent_id, is_terminateds, start_index=1, shift=-1 - # ) - ) - # If a list the list could be empty, if the agent never stepped. - agent_is_terminated = ( - False if not agent_is_terminated else agent_is_terminated[-1] - ) - - agent_is_truncated = ( - [False] - if is_truncateds is None - else self._get_single_agent_data( - agent_id, - is_truncateds, - use_global_t_to_local_t=False, - ) - ) - # If a list the list could be empty, if the agent never stepped. - agent_is_truncated = ( - False if not agent_is_truncated else agent_is_truncated[-1] - ) + agent_is_terminated = terminateds.get(agent_id, False) + agent_is_truncated = truncateds.get(agent_id, False) # If there are as many actions as observations we have to buffer. if ( @@ -1626,14 +1569,18 @@ def _generate_single_agent_episode( ): # Assert then that the other data is in order. if agent_extra_model_outputs: - assert len(agent_extra_model_outputs) == len( - agent_actions - ), f"Agent {agent_id} has not as many extra model outputs as " - "actions." + assert all( + len(v) == len(agent_actions) + for v in agent_extra_model_outputs.values() + ), ( + f"Agent {agent_id} doesn't have the same number of " + "`extra_model_outputs` as it has actions " + f"({len(agent_actions)})." + ) # Put the last extra model outputs into the buffer. self.agent_buffers[agent_id]["extra_model_outputs"].get_nowait() self.agent_buffers[agent_id]["extra_model_outputs"].put_nowait( - agent_extra_model_outputs.pop() + {k: v.pop() for k, v in agent_extra_model_outputs.items()} ) # Put the last action into the buffer. @@ -1643,8 +1590,7 @@ def _generate_single_agent_episode( # `_generate_partial_rewards` method and can be done where # the global timestep and global action timestep # mappings are created (__init__). - # We have to take care of partial rewards when generating the - # agent rewards: + # We have to take care of partial rewards when generating the agent rewards: # 1. Rewards between different observations -> added up and # assigned to next observation. # 2. Rewards after the last observation -> get buffered and added up @@ -1671,7 +1617,6 @@ def _generate_single_agent_episode( if (t + 1) in self.global_t_to_local_t[agent_id][1:]: agent_rewards.append(agent_reward) agent_reward = 0.0 - continue # If the agent reward is not zero, we must have rewards that came # after the last observation. Then we buffer this reward. @@ -1688,8 +1633,8 @@ def _generate_single_agent_episode( actions=agent_actions, rewards=agent_rewards, infos=agent_infos, - is_terminated=agent_is_terminated, - is_truncated=agent_is_truncated, + terminated=agent_is_terminated, + truncated=agent_is_truncated, extra_model_outputs=agent_extra_model_outputs, ) # Otherwise return empty `SingleAgentEpisode`. @@ -1952,7 +1897,7 @@ def __len__(self): sum(len(agent_map) for agent_map in self.global_t_to_local_t.values()) > 0 ), ( "ERROR: Cannot determine length of episode that hasn't started, yet!" - "Call `MultiAgentEpisode.add_initial_observation(initial_observation=)` " + "Call `MultiAgentEpisode.add_env_reset(observations=)` " "first (after which `len(MultiAgentEpisode)` will be 0)." ) return self.t - self.t_started diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index bfee8ba7482b..b2c3f137701c 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -9,6 +9,7 @@ from ray.rllib.core.models.base import STATE_IN, STATE_OUT from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.env.utils import _gym_env_creator from ray.rllib.evaluation.metrics import RolloutMetrics from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch @@ -22,9 +23,6 @@ if TYPE_CHECKING: from ray.rllib.algorithms.algorithm_config import AlgorithmConfig - # TODO (sven): This gives a tricky circular import that goes - # deep into the library. We have to see, where to dissolve it. - from ray.rllib.env.single_agent_episode import SingleAgentEpisode _, tf, _ = try_import_tf() torch, nn = try_import_torch() @@ -74,15 +72,20 @@ def __init__(self, config: "AlgorithmConfig", **kwargs): # Create our own instance of the (single-agent) `RLModule` (which # the needs to be weight-synched) each iteration. - module_spec: SingleAgentRLModuleSpec = self.config.get_default_rl_module_spec() - module_spec.observation_space = self.env.envs[0].observation_space - # TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should - # actually hold the spaces for a single env, but for boxes the - # shape is (1, 1) which brings a problem with the action dists. - # shape=(1,) is expected. - module_spec.action_space = self.env.envs[0].action_space - module_spec.model_config_dict = self.config.model - self.module: RLModule = module_spec.build() + try: + module_spec: SingleAgentRLModuleSpec = ( + self.config.get_default_rl_module_spec() + ) + module_spec.observation_space = self.env.envs[0].observation_space + # TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should + # actually hold the spaces for a single env, but for boxes the + # shape is (1, 1) which brings a problem with the action dists. + # shape=(1,) is expected. + module_spec.action_space = self.env.envs[0].action_space + module_spec.model_config_dict = self.config.model + self.module: RLModule = module_spec.build() + except NotImplementedError: + self.module = None # This should be the default. self._needs_initial_reset: bool = True @@ -95,6 +98,11 @@ def __init__(self, config: "AlgorithmConfig", **kwargs): self._ts_since_last_metrics: int = 0 self._weights_seq_no: int = 0 + # TODO (sven): This is a temporary solution. STATE_OUTs + # will be resolved entirely as `extra_model_outputs` and + # not be stored separately inside Episodes. + self._states = [None for _ in range(self.num_envs)] + @override(EnvRunner) def sample( self, @@ -106,6 +114,7 @@ def sample( with_render_data: bool = False, ) -> List["SingleAgentEpisode"]: """Runs and returns a sample (n timesteps or m episodes) on the env(s).""" + assert not (num_timesteps is not None and num_episodes is not None) # If not execution details are provided, use the config. if num_timesteps is None and num_episodes is None: @@ -176,13 +185,11 @@ def _sample_timesteps( for i in range(self.num_envs): # TODO (sven): Maybe move this into connector pipeline # (even if automated). - self._episodes[i].add_initial_observation( - initial_observation=obs[i], - initial_info=infos[i], - # TODO (simon): Check, if this works for the default - # stateful encoders. - initial_state={k: s[i] for k, s in states.items()}, + self._episodes[i].add_env_reset( + observation=obs[i], + infos=infos[i], ) + self._states[i] = {k: s[i] for k, s in states.items()} # Do not reset envs, but instead continue in already started episodes. else: # Pick up stored observations and states from previous timesteps. @@ -193,8 +200,8 @@ def _sample_timesteps( states = { k: np.stack( [ - initial_states[k][i] if eps.states is None else eps.states[k] - for i, eps in enumerate(self._episodes) + initial_states[k][i] if state is None else state[k] + for i, state in enumerate(self._states) ] ) for k in initial_states.keys() @@ -207,7 +214,8 @@ def _sample_timesteps( # Act randomly. if random_actions: actions = self.env.action_space.sample() - # TODO (simon): Add action_logp for smapled actions. + action_logp = np.zeros(shape=(actions.shape[0],)) + fwd_out = {} # Compute an action using the RLModule. else: # Note, RLModule `forward()` methods expect `NestedDict`s. @@ -228,6 +236,8 @@ def _sample_timesteps( else: fwd_out = self.module.forward_inference(batch) + # TODO (sven): Will be completely replaced by connector logic in + # upcoming PR. actions, action_logp = self._sample_actions_if_necessary( fwd_out, explore ) @@ -260,54 +270,61 @@ def _sample_timesteps( if terminateds[i] or truncateds[i]: # Finish the episode with the actual terminal observation stored in # the info dict. - self._episodes[i].add_timestep( + self._episodes[i].add_env_step( # Gym vector env provides the `"final_observation"`. infos[i]["final_observation"], actions[i], rewards[i], - info=infos[i]["final_info"], - state=s, - is_terminated=terminateds[i], - is_truncated=truncateds[i], - extra_model_output=extra_model_output, + infos=infos[i]["final_info"], + terminated=terminateds[i], + truncated=truncateds[i], + extra_model_outputs=extra_model_output, ) + self._states[i] = s # Reset h-states to nthe model's intiial ones b/c we are starting a # new episode. - for k, v in self.module.get_initial_state().items(): - states[k][i] = convert_to_numpy(v) + if hasattr(self.module, "get_initial_state"): + for k, v in self.module.get_initial_state().items(): + states[k][i] = convert_to_numpy(v) - done_episodes_to_return.append(self._episodes[i]) - # Create a new episode object. + done_episodes_to_return.append(self._episodes[i].finalize()) + # Create a new episode object with already the reset data in it. self._episodes[i] = SingleAgentEpisode( - observations=[obs[i]], infos=[infos[i]], states=s + observations=[obs[i]], infos=[infos[i]] ) + self._states[i] = s else: - self._episodes[i].add_timestep( + self._episodes[i].add_env_step( obs[i], actions[i], rewards[i], - info=infos[i], - state=s, - extra_model_output=extra_model_output, + infos=infos[i], + extra_model_outputs=extra_model_output, ) + self._states[i] = s # Return done episodes ... self._done_episodes_for_metrics.extend(done_episodes_to_return) + # Also, make sure, we return a copy and start new chunks so that callers + # of this function do not alter the ongoing and returned Episode objects. + new_episodes = [eps.cut() for eps in self._episodes] + # ... and all ongoing episode chunks. # Initialized episodes do not have recorded any step and lack # `extra_model_outputs`. - ongoing_episodes = [episode for episode in self._episodes if episode.t > 0] - # Also, make sure, we return a copy and start new chunks so that callers - # of this function do not alter the ongoing and returned Episode objects. - self._episodes = [eps.create_successor() for eps in self._episodes] - for eps in ongoing_episodes: + ongoing_episodes_to_return = [ + episode.finalize() for episode in self._episodes if episode.t > 0 + ] + for eps in ongoing_episodes_to_return: self._ongoing_episodes_for_metrics[eps.id_].append(eps) # Record last metrics collection. self._ts_since_last_metrics += ts - return done_episodes_to_return + ongoing_episodes + self._episodes = new_episodes + + return done_episodes_to_return + ongoing_episodes_to_return def _sample_episodes( self, @@ -320,42 +337,50 @@ def _sample_episodes( See docstring of `self.sample()` for more details. """ - # TODO (sven): This gives a tricky circular import that goes # deep into the library. We have to see, where to dissolve it. from ray.rllib.env.single_agent_episode import SingleAgentEpisode + # If user calls sample(num_timesteps=..) after this, we must reset again + # at the beginning. + self._needs_initial_reset = True + done_episodes_to_return: List["SingleAgentEpisode"] = [] obs, infos = self.env.reset() episodes = [SingleAgentEpisode() for _ in range(self.num_envs)] - # Multiply states n times according to our vector env batch size (num_envs). - states = tree.map_structure( - lambda s: np.repeat(s, self.num_envs, axis=0), - self.module.get_initial_state(), - ) + # Get initial states for all 'batch_size_B` rows in the forward batch, + # i.e. for all vector sub_envs. + if hasattr(self.module, "get_initial_state"): + states = tree.map_structure( + lambda s: np.repeat(s, self.num_envs, axis=0), + self.module.get_initial_state(), + ) + else: + states = {} render_images = [None] * self.num_envs if with_render_data: render_images = [e.render() for e in self.env.envs] for i in range(self.num_envs): - # TODO (sven): Maybe move this into connector pipeline - # (even if automated). - episodes[i].add_initial_observation( - initial_observation=obs[i], - initial_info=infos[i], - initial_state={k: s[i] for k, s in states.items()}, - initial_render_image=render_images[i], + episodes[i].add_env_reset( + observation=obs[i], + infos=infos[i], + render_image=render_images[i], ) eps = 0 while eps < num_episodes: if random_actions: actions = self.env.action_space.sample() + action_logp = np.zeros(shape=(actions.shape[0])) + fwd_out = {} else: batch = { + # TODO (sven): This will move entirely into connector logic in + # upcoming PR. STATE_IN: tree.map_structure( lambda s: self._convert_from_numpy(s), states ), @@ -368,17 +393,18 @@ def _sample_episodes( else: fwd_out = self.module.forward_inference(batch) + # TODO (sven): This will move entirely into connector logic in upcoming + # PR. actions, action_logp = self._sample_actions_if_necessary( fwd_out, explore ) fwd_out = convert_to_numpy(fwd_out) + # TODO (sven): This will move entirely into connector logic in upcoming + # PR. if STATE_OUT in fwd_out: states = convert_to_numpy(fwd_out[STATE_OUT]) - # states = tree.map_structure( - # lambda s: s.numpy(), fwd_out[STATE_OUT] - # ) obs, rewards, terminateds, truncateds, infos = self.env.step(actions) if with_render_data: @@ -387,28 +413,27 @@ def _sample_episodes( for i in range(self.num_envs): # Extract info and state for vector sub_env. # info = {k: v[i] for k, v in infos.items()} - s = {k: s[i] for k, s in states.items()} # The last entry in self.observations[i] is already the reset # obs of the new episode. extra_model_output = {} for k, v in fwd_out.items(): if SampleBatch.ACTIONS not in k: extra_model_output[k] = v[i] - # TODO (simon, sven): Some algos do not have logps. + # TODO (sven): This will move entirely into connector logic in upcoming + # PR. extra_model_output[SampleBatch.ACTION_LOGP] = action_logp[i] if terminateds[i] or truncateds[i]: eps += 1 - episodes[i].add_timestep( + episodes[i].add_env_step( infos[i]["final_observation"], actions[i], rewards[i], - info=infos[i]["final_info"], - state=s, - is_terminated=terminateds[i], - is_truncated=truncateds[i], - extra_model_output=extra_model_output, + infos=infos[i]["final_info"], + terminated=terminateds[i], + truncated=truncateds[i], + extra_model_outputs=extra_model_output, ) done_episodes_to_return.append(episodes[i]) @@ -418,38 +443,33 @@ def _sample_episodes( if eps == num_episodes: break - # Reset h-states to the model's initial ones b/c we are starting - # a new episode. - for k, v in self.module.get_initial_state().items(): - states[k][i] = (convert_to_numpy(v),) + # TODO (sven): This will move entirely into connector logic in + # upcoming PR. + if hasattr(self.module, "get_initial_state"): + for k, v in self.module.get_initial_state().items(): + states[k][i] = (convert_to_numpy(v),) # Create a new episode object. episodes[i] = SingleAgentEpisode( observations=[obs[i]], infos=[infos[i]], - states=s, render_images=None if render_images[i] is None else [render_images[i]], ) else: - episodes[i].add_timestep( + episodes[i].add_env_step( obs[i], actions[i], rewards[i], - info=infos[i], - state=s, + infos=infos[i], render_image=render_images[i], - extra_model_output=extra_model_output, + extra_model_outputs=extra_model_output, ) self._done_episodes_for_metrics.extend(done_episodes_to_return) self._ts_since_last_metrics += sum(len(eps) for eps in done_episodes_to_return) - # If user calls sample(num_timesteps=..) after this, we must reset again - # at the beginning. - self._needs_initial_reset = True - # Initialized episodes have to be removed as they lack `extra_model_outputs`. return [episode for episode in done_episodes_to_return if episode.t > 0] diff --git a/rllib/env/single_agent_episode.py b/rllib/env/single_agent_episode.py index 25c7a02bb046..8071f0d33520 100644 --- a/rllib/env/single_agent_episode.py +++ b/rllib/env/single_agent_episode.py @@ -1,127 +1,301 @@ +import functools +from collections import defaultdict import numpy as np import uuid +import gymnasium as gym from gymnasium.core import ActType, ObsType -from typing import Any, Dict, List, Optional, SupportsFloat +from typing import Any, Dict, List, Optional, SupportsFloat, Union from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.env.utils import BufferWithInfiniteLookback class SingleAgentEpisode: + """A class representing RL environment episodes for individual agents. + + SingleAgentEpisode stores observations, info dicts, actions, rewards, and all + module outputs (e.g. state outs, action logp, etc..) for an individual agent within + some single-agent or multi-agent environment. + The two main APIs to add data to an ongoing episode are the `add_env_reset()` + and `add_env_step()` methods, which should be called passing the outputs of the + respective gym.Env API calls: `env.reset()` and `env.step()`. + + A SingleAgentEpisode might also only represent a chunk of an episode, which is + useful for cases, in which partial (non-complete episode) sampling is performed + and collected episode data has to be returned before the actual gym.Env episode has + finished (see `SingleAgentEpisode.cut()`). In order to still maintain visibility + onto past experiences within such a "cut" episode, SingleAgentEpisode instances + can have a "lookback buffer" of n timesteps at their beginning (left side), which + solely exists for the purpose of compiling extra data (e.g. "prev. reward"), but + is not considered part of the finished/packaged episode (b/c the data in the + lookback buffer is already part of a previous episode chunk). + + Powerful getter methods, such as `get_observations()` help collect different types + of data from the episode at individual time indices or time ranges, including the + "lookback buffer" range described above. For example, to extract the last 4 rewards + of an ongoing episode, one can call `self.get_rewards(slice(-4, None))` or + `self.rewards[-4:]`. This would work, even if the ongoing SingleAgentEpisode is + a continuation chunk from a much earlier started episode, as long as it has a + lookback buffer size of sufficient size. + + Examples: + + .. testcode:: + + import gymnasium as gym + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + # Construct a new episode (without any data in it yet). + episode = SingleAgentEpisode() + assert len(episode) == 0 + + # Fill the episode with some data (10 timesteps). + env = gym.make("CartPole-v1") + obs, infos = env.reset() + episode.add_env_reset(obs, infos) + + # Even with the initial obs/infos, the episode is still considered len=0. + assert len(episode) == 0 + for _ in range(10): + action = env.action_space.sample() + obs, reward, term, trunc, infos = env.step(action) + episode.add_env_step( + observation=obs, + action=action, + reward=reward, + terminated=term, + truncated=trunc, + infos=infos, + ) + assert len(episode) == 10 + + # We can now access information from the episode via the getters. + + # Get the last 3 rewards (in a batch of size 3). + episode.get_rewards(slice(-3, None)) # same as `episode.rewards[-3:]` + + # Get the most recent action (single item, not batched). + # This works regardless of the action space or whether the episode has + # been finalized or not (see below). + episode.get_actions(-1) # same as episode.actions[-1] + + # Looking back from ts=1, get the previous 4 rewards AND fill with 0.0 + # in case we go over the beginning (ts=0). So we would expect + # [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first received + # reward in the episode: + episode.get_rewards(slice(-4, 0), neg_indices_left_of_zero=True, fill=0.0) + + # Note the use of fill=0.0 here (fill everything that's out of range with this + # value) AND the argument `neg_indices_left_of_zero=True`, which interprets + # negative indices as being left of ts=0 (e.g. -1 being the timestep before + # ts=0). + + # Assuming we had a complex action space (nested gym.spaces.Dict) with one or + # more elements being Discrete or MultiDiscrete spaces: + # 1) The `fill=...` argument would still work, filling all spaces (Boxes, + # Discrete) with that provided value. + # 2) Setting the flag `one_hot_discrete=True` would convert those discrete + # sub-components automatically into one-hot (or multi-one-hot) tensors. + # This simplifies the task of having to provide the previous 4 (nested and + # partially discrete/multi-discrete) actions for each timestep within a training + # batch, thereby filling timesteps before the episode started with 0.0s and + # one-hot'ing the discrete/multi-discrete components in these actions: + episode = SingleAgentEpisode(action_space=gym.spaces.Dict({ + "a": gym.spaces.Discrete(3), + "b": gym.spaces.MultiDiscrete([2, 3]), + "c": gym.spaces.Box(-1.0, 1.0, (2,)), + })) + + # ... fill episode with data ... + episode.add_env_reset(observation=0) + # ... from a few steps. + episode.add_env_step( + observation=1, + action={"a":0, "b":np.array([1, 2]), "c":np.array([.5, -.5], np.float32)}, + reward=1.0, + ) + + # In your connector + prev_4_a = [] + # Note here that len(episode) does NOT include the lookback buffer. + for ts in range(len(episode)): + prev_4_a.append( + episode.get_actions( + indices=slice(ts - 4, ts), + # Make sure negative indices are interpreted as + # "into lookback buffer" + neg_indices_left_of_zero=True, + # Zero-out everything even further before the lookback buffer. + fill=0.0, + # Take care of discrete components (get ready as NN input). + one_hot_discrete=True, + ) + ) + + # Finally, convert from list of batch items to a struct (same as action space) + # of batched (numpy) arrays, in which all leafs have B==len(prev_4_a). + from ray.rllib.utils.spaces.space_utils import batch + + prev_4_actions_col = batch(prev_4_a) + """ + def __init__( self, id_: Optional[str] = None, *, observations: List[ObsType] = None, + observation_space: Optional[gym.Space] = None, actions: List[ActType] = None, + action_space: Optional[gym.Space] = None, rewards: List[SupportsFloat] = None, infos: List[Dict] = None, - states=None, - t_started: Optional[int] = None, - is_terminated: bool = False, - is_truncated: bool = False, - render_images: Optional[List[np.ndarray]] = None, + terminated: bool = False, + truncated: bool = False, extra_model_outputs: Optional[Dict[str, Any]] = None, + render_images: Optional[List[np.ndarray]] = None, + t_started: Optional[int] = None, + len_lookback_buffer: Optional[int] = 0, ) -> "SingleAgentEpisode": - """Initializes a `SingleAgentEpisode` instance. + """Initializes a SingleAgentEpisode instance. - This constructor can be called with or without sampled data. Note - that if data is provided the episode will start at timestep - `t_started = len(observations) - 1` (the initial observation is not - counted). If the episode should start at `t_started = 0` (e.g. - because the instance should simply store episode data) this has to - be provided in the `t_started` parameter of the constructor. + This constructor can be called with or without already sampled data, part of + which might then go into the lookback buffer. Args: id_: Optional. Unique identifier for this episode. If no id is provided the constructor generates a hexadecimal code for the id. observations: Optional. A list of observations from a rollout. If data is provided it should be complete (i.e. observations, actions, - rewards, is_terminated, is_truncated, and all necessary + rewards, terminated, truncated, and all necessary `extra_model_outputs`). The length of the `observations` defines the default starting value. See the parameter `t_started`. actions: Optional. A list of actions from a rollout. If data is provided it should be complete (i.e. observations, actions, - rewards, is_terminated, is_truncated, and all necessary + rewards, terminated, truncated, and all necessary `extra_model_outputs`). rewards: Optional. A list of rewards from a rollout. If data is provided it should be complete (i.e. observations, actions, - rewards, is_terminated, is_truncated, and all necessary + rewards, terminated, truncated, and all necessary `extra_model_outputs`). infos: Optional. A list of infos from a rollout. If data is provided it should be complete (i.e. observations, actions, - rewards, is_terminated, is_truncated, and all necessary + rewards, terminated, truncated, and all necessary `extra_model_outputs`). states: Optional. The hidden model states from a rollout. If data is provided it should be complete (i.e. observations, actions, - rewards, is_terminated, is_truncated, and all necessary + rewards, terminated, truncated, and all necessary `extra_model_outputs`). States are only avasilable if a stateful model (`RLModule`) is used. - t_started: Optional. The starting timestep of the episode. The default - is zero. If data is provided, the starting point is from the last - observation onwards (i.e. `t_started = len(observations) - 1). If - this parameter is provided the episode starts at the provided value. - is_terminated: Optional. A boolean indicating, if the episode is already + terminated: Optional. A boolean indicating, if the episode is already terminated. Note, this parameter is only needed, if episode data is provided in the constructor. The default is `False`. - is_truncated: Optional. A boolean indicating, if the episode was + truncated: Optional. A boolean indicating, if the episode was truncated. Note, this parameter is only needed, if episode data is provided in the constructor. The default is `False`. - render_images: Optional. A list of RGB uint8 images from rendering - the environment. extra_model_outputs: Optional. A list of dictionaries containing specific model outputs for the algorithm used (e.g. `vf_preds` and `action_logp` for PPO) from a rollout. If data is provided it should be complete - (i.e. observations, actions, rewards, is_terminated, is_truncated, + (i.e. observations, actions, rewards, terminated, truncated, and all necessary `extra_model_outputs`). + render_images: Optional. A list of RGB uint8 images from rendering + the environment. + t_started: Optional. The starting timestep of the episode. The default + is zero. If data is provided, the starting point is from the last + observation onwards (i.e. `t_started = len(observations) - 1). If + this parameter is provided the episode starts at the provided value. + len_lookback_buffer: The size of an optional lookback buffer to keep in + front of this Episode. If >0, will interpret the first + `len_lookback_buffer` items in each data as NOT part of this actual + episode chunk, but instead serve as historic data that may be viewed. + If None, will interpret all provided data in constructor as part of the + lookback buffer. """ self.id_ = id_ or uuid.uuid4().hex + + # Lookback buffer length is provided. + if len_lookback_buffer is not None: + self._len_lookback_buffer = len_lookback_buffer + # Lookback buffer length is not provided. Interpret already given data as + # lookback buffer. + else: + self._len_lookback_buffer = len(rewards or []) + + infos = infos or [{} for _ in range(len(observations or []))] + # Observations: t0 (initial obs) to T. - self.observations = [] if observations is None else observations + self.observation_space = observation_space + self.observations = BufferWithInfiniteLookback( + data=observations, + lookback=self._len_lookback_buffer, + space=observation_space, + ) # Actions: t1 to T. - self.actions = [] if actions is None else actions + self.action_space = action_space + self.actions = BufferWithInfiniteLookback( + data=actions, + lookback=self._len_lookback_buffer, + space=action_space, + ) # Rewards: t1 to T. - self.rewards = [] if rewards is None else rewards + self.rewards = BufferWithInfiniteLookback( + data=rewards, + lookback=self._len_lookback_buffer, + space=gym.spaces.Box(float("-inf"), float("inf"), (), np.float32), + ) # Infos: t0 (initial info) to T. - if infos is None: - self.infos = [{} for _ in range(len(self.observations))] - else: - self.infos = infos - # h-states: t0 (in case this episode is a continuation chunk, we need to know - # about the initial h) to T. - # TODO (simon): Create as list not as singleton. - self.states = states - # The global last timestep of the episode and the timesteps when this chunk - # started. - # TODO (simon): Check again what are the consequences of this decision for - # the methods of this class. For example the `validate()` method or - # `create_successor`. Write a test. - # Note, the case `t_started > len(observations) - 1` can occur, if a user - # wants to have an episode that is ongoing but does not want to carry the - # stale data from the last rollout in it. - self.t = self.t_started = ( - t_started if t_started is not None else max(len(self.observations) - 1, 0) + self.infos = BufferWithInfiniteLookback( + data=infos, + lookback=self._len_lookback_buffer, ) - if self.t_started < len(self.observations) - 1: - self.t = len(self.observations) - 1 # obs[-1] is the final observation in the episode. - self.is_terminated = is_terminated + self.is_terminated = terminated # obs[-1] is the last obs in a truncated-by-the-env episode (there will no more # observations in following chunks for this episode). - self.is_truncated = is_truncated + self.is_truncated = truncated + # Extra model outputs, e.g. `action_dist_input` needed in the batch. + self.extra_model_outputs = defaultdict( + functools.partial( + BufferWithInfiniteLookback, + lookback=self._len_lookback_buffer, + ), + ) + for k, v in (extra_model_outputs or {}).items(): + if isinstance(v, BufferWithInfiniteLookback): + self.extra_model_outputs[k] = v + else: + self.extra_model_outputs[k].data = v + # RGB uint8 images from rendering the env; the images include the corresponding # rewards. assert render_images is None or observations is not None - self.render_images = [] if render_images is None else render_images - # Extra model outputs, e.g. `action_dist_input` needed in the batch. - self.extra_model_outputs = ( - {} if extra_model_outputs is None else extra_model_outputs + self.render_images = render_images or [] + + # The global last timestep of the episode and the timesteps when this chunk + # started (excluding a possible lookback buffer). + self.t_started = t_started or 0 + + self.t = ( + (len(rewards) if rewards is not None else 0) + - self._len_lookback_buffer + + self.t_started ) + # Validate the episode data thus far. + self.validate() + def concat_episode(self, episode_chunk: "SingleAgentEpisode"): """Adds the given `episode_chunk` to the right side of self. + In order for this to work, both chunks (`self` and `episode_chunk`) must fit + together. This is checked by the IDs (must be identical), the time step counters + (`self.t` must be the same as `episode_chunk.t_started`), as well as the + observations/infos at the concatenation boundaries (`self.observations[-1]` + must match `episode_chunk.observations[0]`). Also, `self.is_done` must not be + True, meaning `self.is_terminated` and `self.is_truncated` are both False. + Args: episode_chunk: Another `SingleAgentEpisode` to be concatenated. @@ -129,7 +303,7 @@ def concat_episode(self, episode_chunk: "SingleAgentEpisode"): from both episodes. """ assert episode_chunk.id_ == self.id_ - assert not self.is_done + assert not self.is_done and not self.is_finalized # Make sure the timesteps match. assert self.t == episode_chunk.t_started @@ -144,41 +318,42 @@ def concat_episode(self, episode_chunk: "SingleAgentEpisode"): # Extend ourselves. In case, episode_chunk is already terminated (and numpyfied) # we need to convert to lists (as we are ourselves still filling up lists). - self.observations.extend(list(episode_chunk.observations)) - self.actions.extend(list(episode_chunk.actions)) - self.rewards.extend(list(episode_chunk.rewards)) - self.infos.extend(list(episode_chunk.infos)) + self.observations.extend(episode_chunk.get_observations()) + self.actions.extend(episode_chunk.get_actions()) + self.rewards.extend(episode_chunk.get_rewards()) + self.infos.extend(episode_chunk.get_infos()) self.t = episode_chunk.t - self.states = episode_chunk.states if episode_chunk.is_terminated: self.is_terminated = True elif episode_chunk.is_truncated: self.is_truncated = True - for k, v in episode_chunk.extra_model_outputs.items(): - self.extra_model_outputs[k].extend(list(v)) + for model_out_key in episode_chunk.extra_model_outputs.keys(): + self.extra_model_outputs[model_out_key].extend( + episode_chunk.get_extra_model_outputs(model_out_key) + ) # Validate. self.validate() - def add_initial_observation( + def add_env_reset( self, + observation: ObsType, + infos: Optional[Dict] = None, *, - initial_observation: ObsType, - initial_info: Optional[Dict] = None, - initial_state=None, - initial_render_image: Optional[np.ndarray] = None, + render_image: Optional[np.ndarray] = None, ) -> None: - """Adds the initial data to the episode. + """Adds the initial data (after an `env.reset()`) to the episode. + + This data consists of initial observations and initial infos, as well as + - optionally - a render image. Args: - initial_observation: Obligatory. The initial observation. - initial_info: Optional. The initial info. - initial_state: Optional. The initial hidden state of a - model (`RLModule`) if the latter is stateful. - initial_render_image: Optional. An RGB uint8 image from rendering - the environment. + observation: The initial observation returned by `env.reset()`. + infos: An (optional) info dict returned by `env.reset()`. + render_image: Optional. An RGB uint8 image from rendering + the environment right after the reset. """ assert not self.is_done assert len(self.observations) == 0 @@ -186,48 +361,55 @@ def add_initial_observation( # Leave self.t (and self.t_started) at 0. assert self.t == self.t_started == 0 - initial_info = initial_info or {} + infos = infos or {} + + if self.observation_space is not None: + assert self.observation_space.contains(observation), ( + f"`observation` {observation} does NOT fit SingleAgentEpisode's " + f"observation_space: {self.observation_space}!" + ) + + self.observations.append(observation) + self.infos.append(infos) + if render_image is not None: + self.render_images.append(render_image) - self.observations.append(initial_observation) - self.states = initial_state - self.infos.append(initial_info) - if initial_render_image is not None: - self.render_images.append(initial_render_image) - # TODO (sven): Do we have to call validate here? It is our own function - # that manipulates the object. + # Validate our data. self.validate() - def add_timestep( + def add_env_step( self, observation: ObsType, action: ActType, reward: SupportsFloat, + infos: Optional[Dict[str, Any]] = None, *, - info: Optional[Dict[str, Any]] = None, - state=None, - is_terminated: bool = False, - is_truncated: bool = False, + terminated: bool = False, + truncated: bool = False, render_image: Optional[np.ndarray] = None, - extra_model_output: Optional[Dict[str, Any]] = None, + extra_model_outputs: Optional[Dict[str, Any]] = None, ) -> None: - """Adds a timestep to the episode. + """Adds results of an `env.step()` call (including the action) to this episode. + + This data consists of an observation and info dict, an action, a reward, + terminated/truncated flags, extra model outputs (e.g. action probabilities or + RNN internal state outputs), and - optionally - a render image. Args: - observation: The observation received from the - environment. - action: The last action used by the agent. - reward: The last reward received by the agent. - info: The last info recevied from the environment. - state: Optional. The last hidden state of the model (`RLModule` ). - This is only available, if the model is stateful. - is_terminated: A boolean indicating, if the environment has been - terminated. - is_truncated: A boolean indicating, if the environment has been - truncated. + observation: The observation received from the environment after taking + `action`. + action: The last action used by the agent during the call to `env.step()`. + reward: The last reward received by the agent after taking `action`. + infos: The last info received from the environment after taking `action`. + terminated: A boolean indicating, if the environment has been + terminated (after taking `action`). + truncated: A boolean indicating, if the environment has been + truncated (after taking `action`). render_image: Optional. An RGB uint8 image from rendering - the environment. - extra_model_output: The last timestep's specific model outputs - (e.g. `vf_preds` for PPO). + the environment (after taking `action`). + extra_model_outputs: The last timestep's specific model outputs. + These are normally outputs of an RLModule that were computed along with + `action`, e.g. `action_logp` or `action_dist_inputs`. """ # Cannot add data to an already done episode. assert ( @@ -237,55 +419,65 @@ def add_timestep( self.observations.append(observation) self.actions.append(action) self.rewards.append(reward) - info = info or {} - self.infos.append(info) - self.states = state + infos = infos or {} + self.infos.append(infos) self.t += 1 if render_image is not None: self.render_images.append(render_image) - if extra_model_output is not None: - for k, v in extra_model_output.items(): - if k not in self.extra_model_outputs: - self.extra_model_outputs[k] = [v] - else: - self.extra_model_outputs[k].append(v) - self.is_terminated = is_terminated - self.is_truncated = is_truncated + if extra_model_outputs is not None: + for k, v in extra_model_outputs.items(): + self.extra_model_outputs[k].append(v) + self.is_terminated = terminated + self.is_truncated = truncated + + # Validate our data. self.validate() + # Only check spaces every n timesteps. + if self.t % 50: + if self.observation_space is not None: + assert self.observation_space.contains(observation), ( + f"`observation` {observation} does NOT fit SingleAgentEpisode's " + f"observation_space: {self.observation_space}!" + ) + if self.action_space is not None: + assert self.action_space.contains(action), ( + f"`action` {action} does NOT fit SingleAgentEpisode's " + f"action_space: {self.action_space}!" + ) def validate(self) -> None: - """Validates the episode. + """Validates the episode's data. This function ensures that the data stored to a `SingleAgentEpisode` is in order (e.g. that the correct number of observations, actions, rewards are there). """ + assert len(self.observations) == len(self.infos) + if len(self.observations) == 0: + assert len(self.infos) == len(self.rewards) == len(self.actions) == 0 + for k, v in self.extra_model_outputs.items(): + assert len(v) == 0 # Make sure we always have one more obs stored than rewards (and actions) # due to the reset and last-obs logic of an MDP. - assert ( - len(self.observations) - == len(self.infos) - == len(self.rewards) + 1 - == len(self.actions) + 1 - ) - # TODO (sven): This is unclear to me. It makes sense - # to start at a point after the length of experiences - # provided at initialization, but when we test then here - # it will imo always error out. - # Example: we initialize the class by providing 101 observations, - # 100 actions and rewards. - # self.t = self.t_started = len(observations) - 1. Then - # we add a single timestep. self.t += 1 and - # self.t - self.t_started is 1, but len(rewards) is 100. - # assert len(self.rewards) == (self.t - self.t_started) - - if len(self.extra_model_outputs) > 0: + else: + assert ( + len(self.observations) + == len(self.infos) + == len(self.rewards) + 1 + == len(self.actions) + 1 + ) for k, v in self.extra_model_outputs.items(): assert len(v) == len(self.observations) - 1 - # Convert all lists to numpy arrays, if we are terminated. - if self.is_done: - self.convert_lists_to_numpy() + # Make sure, length of pre-buffer and len(self) make sense. + assert self._len_lookback_buffer + len(self) == len(self.rewards.data) + + @property + def is_finalized(self) -> bool: + """True, if the data in this episode is already stored as numpy arrays.""" + # If rewards are still a list, return False. + # Otherwise, rewards should already be a (1D) numpy array. + return self.rewards.finalized @property def is_done(self) -> bool: @@ -297,238 +489,727 @@ def is_done(self) -> bool: """ return self.is_terminated or self.is_truncated - def convert_lists_to_numpy(self) -> None: - """Converts list attributes to numpy arrays. + def finalize(self) -> "SingleAgentEpisode": + """Converts this Episode's list attributes to numpy arrays. + + This means in particular that this episodes' lists of (possibly complex) + data (e.g. if we have a dict obs space) will be converted to (possibly complex) + structs, whose leafs are now numpy arrays. Each of these leaf numpy arrays will + have the same length (batch dimension) as the length of the original lists. - When an episode is terminated or truncated (`self.is_done`) the data - will be not anymore touched and instead converted to numpy for later - use in postprocessing. This function converts all the data stored - into numpy arrays. + Note that SampleBatch.INFOS are NEVER numpy'ized and will remain a list + (normally, a list of the original, env-returned dicts). This is due to the + herterogenous nature of INFOS returned by envs, which would make it unwieldy to + convert this information to numpy arrays. + + After calling this method, no further data may be added to this episode via + the `self.add_env_step()` method. + + Examples: + + .. testcode:: + + import numpy as np + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + observations=[0, 1, 2, 3], + actions=[1, 2, 3], + rewards=[1, 2, 3], + # Note: terminated/truncated have nothing to do with an episode + # being `finalized` or not (via the `self.finalize()` method)! + terminated=False, + ) + # Episode has not been finalized (numpy'ized) yet. + assert not episode.is_finalized + # We are still operating on lists. + assert episode.get_observations([1]) == [1] + assert episode.get_observations(slice(None, 2)) == [0, 1] + # We can still add data (and even add the terminated=True flag). + episode.add_env_step( + observation=4, + action=4, + reward=4, + terminated=True, + ) + # Still NOT finalized. + assert not episode.is_finalized + + # Let's finalize the episode. + episode.finalize() + assert episode.is_finalized + + # We cannot add data anymore. The following would crash. + # episode.add_env_step(observation=5, action=5, reward=5) + + # Everything is now numpy arrays (with 0-axis of size + # B=[len of requested slice]). + assert isinstance(episode.get_observations([1]), np.ndarray) # B=1 + assert isinstance(episode.actions[0:2], np.ndarray) # B=2 + assert isinstance(episode.rewards[1:4], np.ndarray) # B=3 + + Returns: + This `SingleAgentEpisode` object with the converted numpy data. """ - self.observations = np.array(self.observations) - self.actions = np.array(self.actions) - self.rewards = np.array(self.rewards) - self.infos = np.array(self.infos) + self.observations.finalize() + self.actions.finalize() + self.rewards.finalize() self.render_images = np.array(self.render_images, dtype=np.uint8) for k, v in self.extra_model_outputs.items(): - self.extra_model_outputs[k] = np.array(v) + self.extra_model_outputs[k].finalize() - def create_successor(self) -> "SingleAgentEpisode": - """Returns a successor episode chunk (of len=0) continuing with this one. + return self - The successor will have the same ID and state as self and its only observation - will be the last observation in self. Its length will therefore be 0 (no - steps taken yet). + def cut(self, len_lookback_buffer: int = 0) -> "SingleAgentEpisode": + """Returns a successor episode chunk (of len=0) continuing from this Episode. + + The successor will have the same ID as `self`. + If no lookback buffer is requested (len_lookback_buffer=0), the successor's + observations will be the last observation(s) of `self` and its length will + therefore be 0 (no further steps taken yet). If `len_lookback_buffer` > 0, + the returned successor will have `len_lookback_buffer` observations (and + actions, rewards, etc..) taken from the right side (end) of `self`. For example + if `len_lookback_buffer=2`, the returned successor's lookback buffer actions + will be identical to `self.actions[-2:]`. This method is useful if you would like to discontinue building an episode chunk (b/c you have to return it from somewhere), but would like to have a new - episode (chunk) instance to continue building the actual env episode at a later - time. + episode instance to continue building the actual gym.Env episode at a later + time. Vie the `len_lookback_buffer` argument, the continuing chunk (successor) + will still be able to "look back" into this predecessor episode's data (at + least to some extend, depending on the value of `len_lookback_buffer`). + + Args: + len_lookback_buffer: The number of timesteps to take along into the new + chunk as "lookback buffer". A lookback buffer is additional data on + the left side of the actual episode data for visibility purposes + (but without actually being part of the new chunk). For example, if + `self` ends in actions 5, 6, 7, and 8, and we call + `self.cut(len_lookback_buffer=2)`, the returned chunk will have + actions 7 and 8 already in it, but still `t_started`==t==8 (not 7!) and + a length of 0. If there is not enough data in `self` yet to fulfil + the `len_lookback_buffer` request, the value of `len_lookback_buffer` + is automatically adjusted (lowered). Returns: The successor Episode chunk of this one with the same ID and state and the only observation being the last observation in self. """ - assert not self.is_done + assert not self.is_done and len_lookback_buffer >= 0 + + # Initialize this chunk with the most recent obs and infos (even if lookback is + # 0). Similar to an initial `env.reset()`. + indices_obs_and_infos = slice(-len_lookback_buffer - 1, None) + indices_rest = ( + slice(-len_lookback_buffer, None) + if len_lookback_buffer > 0 + else slice(None, 0) + ) return SingleAgentEpisode( # Same ID. id_=self.id_, - # First (and only) observation of successor is this episode's last obs. - observations=[self.observations[-1]], - # First (and only) info of successor is this episode's last info. - infos=[self.infos[-1]], - # Same state. - states=self.states, + observations=self.get_observations(indices=indices_obs_and_infos), + infos=self.get_infos(indices=indices_obs_and_infos), + actions=self.get_actions(indices=indices_rest), + rewards=self.get_rewards(indices=indices_rest), + extra_model_outputs={ + k: self.get_extra_model_outputs(k, indices_rest) + for k in self.extra_model_outputs.keys() + }, # Continue with self's current timestep. t_started=self.t, ) - def to_sample_batch(self) -> SampleBatch: - """Converts a `SingleAgentEpisode` into a `SampleBatch`. + def get_observations( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_indices_left_of_zero: bool = False, + fill: Optional[float] = None, + one_hot_discrete: bool = False, + ) -> Any: + """Returns individual observations or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual observation stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual observations in a batch of size len(indices). + A slice object is interpreted as a range of observations to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_indices_left_of_zero=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + neg_indices_left_of_zero: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with observations [4, 5, 6, 7, 8, 9], + where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will + respond to `get_observations(-1, neg_indices_left_of_zero=True)` + with `6` and to + `get_observations(slice(-2, 1), neg_indices_left_of_zero=True)` with + `[5, 6, 7]`. + fill: An optional float value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with observations [10, 11, 12, 13, 14] and + lookback buffer size of 2 (meaning observations `10` and `11` are part + of the lookback buffer) will respond to + `get_observations(slice(-7, -2), fill=0.0)` with + `[0.0, 0.0, 10, 11, 12]`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) observation + space that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). - Note that `RLlib` is relying in training on the `SampleBatch` class and - therefore episodes have to be converted to this format before training can - start. + Examples: + + .. testcode:: + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + # Discrete(4) observations (ints between 0 and 4 (excl.)) + observation_space=gym.spaces.Discrete(4), + observations=[0, 1, 2, 3], + actions=[1, 2, 3], rewards=[1, 2, 3], # <- not relevant for this demo + ) + # Plain usage (`indices` arg only). + episode.get_observations(-1) # 3 + episode.get_observations(0) # 0 + episode.get_observations([0, 2]) # [0, 2] + episode.get_observations([-1, 0]) # [3, 0] + episode.get_observations(slice(None, 2)) # [0, 1] + episode.get_observations(slice(-2, None)) # [2, 3] + # Using `fill=...` (requesting slices beyond the boundaries). + episode.get_observations(slice(-6, -2), fill=-9) # [-9, -9, 0, 1] + episode.get_observations(slice(2, 5), fill=-7) # [2, 3, -7] + # Using `one_hot_discrete=True`. + episode.get_observations(2, one_hot_discrete=True) # [0 0 1 0] + episode.get_observations(3, one_hot_discrete=True) # [0 0 0 1] + episode.get_observations( + slice(0, 3), + one_hot_discrete=True, + ) # [[1 0 0 0], [0 1 0 0], [0 0 1 0]] + # Special case: Using `fill=0.0` AND `one_hot_discrete=True`. + episode.get_observations( + -1, + neg_indices_left_of_zero=True, # -1 means one left of ts=0 + fill=0.0, + one_hot_discrete=True, + ) # [0 0 0 0] <- all 0s one-hot tensor (note difference to [1 0 0 0]!) Returns: - An `ray.rLlib.policy.sample_batch.SampleBatch` instance containing this - episode's data. + The collected observations. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. """ - return SampleBatch( - { - SampleBatch.EPS_ID: np.array([self.id_] * len(self)), - SampleBatch.OBS: self.observations[:-1], - SampleBatch.NEXT_OBS: self.observations[1:], - SampleBatch.ACTIONS: self.actions, - SampleBatch.REWARDS: self.rewards, - SampleBatch.T: list(range(self.t_started, self.t)), - SampleBatch.TERMINATEDS: np.array( - [False] * (len(self) - 1) + [self.is_terminated] - ), - SampleBatch.TRUNCATEDS: np.array( - [False] * (len(self) - 1) + [self.is_truncated] - ), - # Return the infos after stepping the environment. - SampleBatch.INFOS: self.infos[1:], - **self.extra_model_outputs, - } + return self.observations.get( + indices=indices, + neg_indices_left_of_zero=neg_indices_left_of_zero, + fill=fill, + one_hot_discrete=one_hot_discrete, ) - @staticmethod - def from_sample_batch(batch: SampleBatch) -> "SingleAgentEpisode": - """Converts a `SampleBatch` instance into a `SingleAegntEpisode`. + def get_infos( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_indices_left_of_zero: bool = False, + fill: Optional[Any] = None, + ) -> Any: + """Returns individual info dicts or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual info dict stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual info dicts in a list of size len(indices). + A slice object is interpreted as a range of info dicts to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_indices_left_of_zero=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + neg_indices_left_of_zero: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with infos + [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}], where the + first 3 items are the lookback buffer (ts=0 item is {"a": 7}), will + respond to `get_infos(-1, neg_indices_left_of_zero=True)` with + `{"l":6}` and to + `get_infos(slice(-2, 1), neg_indices_left_of_zero=True)` with + `[{"l":5}, {"l":6}, {"a":7}]`. + fill: An optional value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to + auto-fill. For example, an episode with infos + [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and lookback buffer + size of 2 (meaning infos {"l":10}, {"l":11} are part of the lookback + buffer) will respond to `get_infos(slice(-7, -2), fill={"o": 0.0})` + with `[{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]`. + TODO (sven): This would require a space being provided. Maybe we can + skip this check for infos, which don't have a space anyways. + + Examples: + + .. testcode:: + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode - The `ray.rllib.policy.sample_batch.SampleBatch` class is used in `RLlib` - for training an agent's modules (`RLModule`), converting from or to - `SampleBatch` can be performed by this function and its counterpart - `to_sample_batch()`. + episode = SingleAgentEpisode( + infos=[{"a":0}, {"b":1}, {"c":2}, {"d":3}], + # The following is needed, but not relevant for this demo. + observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3], + ) + # Plain usage (`indices` arg only). + episode.get_infos(-1) # {"d":3} + episode.get_infos(0) # {"a":0} + episode.get_infos([0, 2]) # [{"a":0},{"c":2}] + episode.get_infos([-1, 0]) # [{"d":3},{"a":0}] + episode.get_infos(slice(None, 2)) # [{"a":0},{"b":1}] + episode.get_infos(slice(-2, None)) # [{"c":2},{"d":3}] + # Using `fill=...` (requesting slices beyond the boundaries). + # TODO (sven): This would require a space being provided. Maybe we can + # skip this check for infos, which don't have a space anyways. + # episode.get_infos(slice(-5, -3), fill={"o":-1}) # [{"o":-1},{"a":0}] + # episode.get_infos(slice(3, 5), fill={"o":-2}) # [{"d":3},{"o":-2}] + + Returns: + The collected info dicts. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. + """ + return self.infos.get( + indices=indices, + neg_indices_left_of_zero=neg_indices_left_of_zero, + fill=fill, + ) + + def get_actions( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_indices_left_of_zero: bool = False, + fill: Optional[float] = None, + one_hot_discrete: bool = False, + ) -> Any: + """Returns individual actions or batched ranges thereof from this episode. Args: - batch: A `SampleBatch` instance. It should contain only a single episode. + indices: A single int is interpreted as an index, from which to return the + individual action stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual actions in a batch of size len(indices). + A slice object is interpreted as a range of actions to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_indices_left_of_zero=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + neg_indices_left_of_zero: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with actions [4, 5, 6, 7, 8, 9], where + [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond + to `get_actions(-1, neg_indices_left_of_zero=True)` with `6` and + to `get_actions(slice(-2, 1), neg_indices_left_of_zero=True)` with + `[5, 6, 7]`. + fill: An optional float value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with actions [10, 11, 12, 13, 14] and + lookback buffer size of 2 (meaning actions `10` and `11` are part + of the lookback buffer) will respond to + `get_actions(slice(-7, -2), fill=0.0)` with `[0.0, 0.0, 10, 11, 12]`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) action + space that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). + + Examples: + + .. testcode:: + + import gymnasium as gym + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + # Discrete(4) actions (ints between 0 and 4 (excl.)) + action_space=gym.spaces.Discrete(4), + actions=[1, 2, 3], + observations=[0, 1, 2, 3], rewards=[1, 2, 3], # <- not relevant here + ) + # Plain usage (`indices` arg only). + episode.get_actions(-1) # 3 + episode.get_actions(0) # 1 + episode.get_actions([0, 2]) # [1, 3] + episode.get_actions([-1, 0]) # [3, 1] + episode.get_actions(slice(None, 2)) # [1, 2] + episode.get_actions(slice(-2, None)) # [2, 3] + # Using `fill=...` (requesting slices beyond the boundaries). + episode.get_actions(slice(-5, -2), fill=-9) # [-9, -9, 1, 2] + episode.get_actions(slice(1, 5), fill=-7) # [2, 3, -7, -7] + # Using `one_hot_discrete=True`. + episode.get_actions(1, one_hot_discrete=True) # [0 0 1 0] (action=2) + episode.get_actions(2, one_hot_discrete=True) # [0 0 0 1] (action=3) + episode.get_actions( + slice(0, 2), + one_hot_discrete=True, + ) # [[0 1 0 0], [0 0 0 1]] (actions=1 and 3) + # Special case: Using `fill=0.0` AND `one_hot_discrete=True`. + episode.get_actions( + -1, + neg_indices_left_of_zero=True, # -1 means one left of ts=0 + fill=0.0, + one_hot_discrete=True, + ) # [0 0 0 0] <- all 0s one-hot tensor (note difference to [1 0 0 0]!) Returns: - An `SingleAegntEpisode` instance containing the data from `batch`. + The collected actions. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. """ - is_done = ( - batch[SampleBatch.TERMINATEDS][-1] or batch[SampleBatch.TRUNCATEDS][-1] + return self.actions.get( + indices=indices, + neg_indices_left_of_zero=neg_indices_left_of_zero, + fill=fill, + one_hot_discrete=one_hot_discrete, ) - observations = np.concatenate( - [batch[SampleBatch.OBS], batch[SampleBatch.NEXT_OBS][None, -1]] + + def get_rewards( + self, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_indices_left_of_zero: bool = False, + fill: Optional[Any] = None, + ) -> Any: + """Returns individual rewards or batched ranges thereof from this episode. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual reward stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual rewards in a batch of size len(indices). + A slice object is interpreted as a range of rewards to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_indices_left_of_zero=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + neg_indices_left_of_zero: Negative values in `indices` are interpreted as + as "before ts=0", meaning going back into the lookback buffer. + For example, an episode with rewards [4, 5, 6, 7, 8, 9], where + [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will respond + to `get_rewards(-1, neg_indices_left_of_zero=True)` with `6` and + to `get_rewards(slice(-2, 1), neg_indices_left_of_zero=True)` with + `[5, 6, 7]`. + fill: An optional float value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with rewards [10, 11, 12, 13, 14] and + lookback buffer size of 2 (meaning rewards `10` and `11` are part + of the lookback buffer) will respond to + `get_rewards(slice(-7, -2), fill=0.0)` with `[0.0, 0.0, 10, 11, 12]`. + + Examples: + + .. testcode:: + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + rewards=[1.0, 2.0, 3.0], + observations=[0, 1, 2, 3], actions=[1, 2, 3], # <- not relevant here + ) + # Plain usage (`indices` arg only). + episode.get_rewards(-1) # 3.0 + episode.get_rewards(0) # 1.0 + episode.get_rewards([0, 2]) # [1.0, 3.0] + episode.get_rewards([-1, 0]) # [3.0, 1.0] + episode.get_rewards(slice(None, 2)) # [1.0, 2.0] + episode.get_rewards(slice(-2, None)) # [2.0, 3.0] + # Using `fill=...` (requesting slices beyond the boundaries). + episode.get_rewards(slice(-5, -2), fill=0.0) # [0.0, 0.0, 1.0, 2.0] + episode.get_rewards(slice(1, 5), fill=0.0) # [2.0, 3.0, 0.0, 0.0] + + Returns: + The collected rewards. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. + """ + return self.rewards.get( + indices=indices, + neg_indices_left_of_zero=neg_indices_left_of_zero, + fill=fill, ) - actions = batch[SampleBatch.ACTIONS] - rewards = batch[SampleBatch.REWARDS] - # These are the infos after stepping the environment, i.e. without the - # initial info. - infos = batch[SampleBatch.INFOS] - # Concatenate an intiial empty info. - infos = np.concatenate([np.array([{}]), infos]) - - # TODO (simon): This is very ugly, but right now - # we can only do it according to the exclusion principle. - extra_model_output_keys = [] - for k in batch.keys(): - if k not in [ - SampleBatch.EPS_ID, - SampleBatch.AGENT_INDEX, - SampleBatch.ENV_ID, - SampleBatch.AGENT_INDEX, - SampleBatch.T, - SampleBatch.SEQ_LENS, - SampleBatch.OBS, - SampleBatch.INFOS, - SampleBatch.NEXT_OBS, - SampleBatch.ACTIONS, - SampleBatch.PREV_ACTIONS, - SampleBatch.REWARDS, - SampleBatch.PREV_REWARDS, - SampleBatch.TERMINATEDS, - SampleBatch.TRUNCATEDS, - SampleBatch.UNROLL_ID, - SampleBatch.DONES, - SampleBatch.CUR_OBS, - ]: - extra_model_output_keys.append(k) - return SingleAgentEpisode( - id_=batch[SampleBatch.EPS_ID][0], - observations=observations if is_done else observations.tolist(), - actions=actions if is_done else actions.tolist(), - rewards=rewards if is_done else rewards.tolist(), - t_started=batch[SampleBatch.T][0], - is_terminated=batch[SampleBatch.TERMINATEDS][-1], - is_truncated=batch[SampleBatch.TRUNCATEDS][-1], - infos=infos if is_done else infos.tolist(), - extra_model_outputs={ - k: (batch[k] if is_done else batch[k].tolist()) - for k in extra_model_output_keys - }, + def get_extra_model_outputs( + self, + key: str, + indices: Optional[Union[int, List[int], slice]] = None, + *, + neg_indices_left_of_zero: bool = False, + fill: Optional[Any] = None, + ) -> Any: + """Returns extra model outputs (under given key) from this episode. + + Args: + key: The `key` within `self.extra_model_outputs` to extract data for. + indices: A single int is interpreted as an index, from which to return an + individual extra model output stored under `key` at index. + A list of ints is interpreted as a list of indices from which to gather + individual actions in a batch of size len(indices). + A slice object is interpreted as a range of extra model outputs to be + returned. Thereby, negative indices by default are interpreted as + "before the end" unless the `neg_indices_left_of_zero=True` option is + used, in which case negative indices are interpreted as "before ts=0", + meaning going back into the lookback buffer. + neg_indices_left_of_zero: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an episode with + extra_model_outputs['a'] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the + lookback buffer range (ts=0 item is 7), will respond to + `get_extra_model_outputs("a", -1, neg_indices_left_of_zero=True)` with + `6` and to `get_extra_model_outputs("a", slice(-2, 1), + neg_indices_left_of_zero=True)` with `[5, 6, 7]`. + fill: An optional float value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the episode's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, an episode with + extra_model_outputs["b"] = [10, 11, 12, 13, 14] and lookback buffer + size of 2 (meaning `10` and `11` are part of the lookback buffer) will + respond to + `get_extra_model_outputs("b", slice(-7, -2), fill=0.0)` with + `[0.0, 0.0, 10, 11, 12]`. + TODO (sven): This would require a space being provided. Maybe we can + automatically infer the space from existing data? + + Examples: + + .. testcode:: + + from ray.rllib.env.single_agent_episode import SingleAgentEpisode + + episode = SingleAgentEpisode( + extra_model_outputs={"mo": [1, 2, 3]}, + # The following is needed, but not relevant for this demo. + observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3], + ) + + # Plain usage (`indices` arg only). + episode.get_extra_model_outputs("mo", -1) # 3 + episode.get_extra_model_outputs("mo", 1) # 0 + episode.get_extra_model_outputs("mo", [0, 2]) # [1, 3] + episode.get_extra_model_outputs("mo", [-1, 0]) # [3, 1] + episode.get_extra_model_outputs("mo", slice(None, 2)) # [1, 2] + episode.get_extra_model_outputs("mo", slice(-2, None)) # [2, 3] + # Using `fill=...` (requesting slices beyond the boundaries). + # TODO (sven): This would require a space being provided. Maybe we can + # automatically infer the space from existing data? + # episode.get_extra_model_outputs("mo", slice(-5, -2), fill=0) # [0, 0, 1] + # episode.get_extra_model_outputs("mo", slice(2, 5), fill=-1) # [3, -1, -1] + + Returns: + The collected extra_model_outputs[`key`]. + As a 0-axis batch, if there are several `indices` or a list of exactly one + index provided OR `indices` is a slice object. + As single item (B=0 -> no additional 0-axis) if `indices` is a single int. + """ + value = self.extra_model_outputs[key] + # The expected case is: `value` is a `BufferWithInfiniteLookback`. + if isinstance(value, BufferWithInfiniteLookback): + return value.get( + indices=indices, + neg_indices_left_of_zero=neg_indices_left_of_zero, + fill=fill, + ) + # It might be that the user has added new key/value pairs in their custom + # postprocessing/connector logic. The values are then most likely numpy + # arrays. We convert them automatically to buffers and get the requested + # indices (with the given options) from there. + return BufferWithInfiniteLookback(value).get( + indices, fill=fill, neg_indices_left_of_zero=neg_indices_left_of_zero ) - def get_return(self) -> float: - """Calculates an episode's return. + def slice(self, slice_: slice) -> "SingleAgentEpisode": + """Returns a slice of this episode with the given slice object. - The return is computed by a simple sum, neglecting the discount factor. - This is used predominantly for metrics. + For example, if `self` contains o0 (the reset observation), o1, o2, o3, and o4 + and the actions a1, a2, a3, and a4 (len of `self` is 4), then a call to + `self.slice(slice(1, 3))` would return a new SingleAgentEpisode with + observations o1, o2, and o3, and actions a2 and a3. Note here that there is + always one observation more in an episode than there are actions (and rewards + and extra model outputs) due to the initial observation received after an env + reset. + + Note that in any case, the lookback buffer will remain (if possible) at the same + size as it has been previously set to (`self._len_lookback_buffer`) and the + given slice object will NOT have to provide for this extra offset at the + beginning. + + Args: + slice_: The slice object to use for slicing. This should exclude the + lookback buffer, which will be prepended automatically to the returned + slice. Returns: - The sum of rewards collected during this episode. + The new SingleAgentEpisode representing the requested slice. """ - return sum(self.rewards) + # Figure out, whether slicing stops at the very end of this episode to know + # whether `self.is_terminated/is_truncated` should be kept as-is. + keep_done = slice_.stop is None or slice_.stop == len(self) + start = slice_.start or 0 + t_started = self.t_started + start + (0 if start >= 0 else len(self)) + + neg_indices_left_of_zero = (slice_.start or 0) >= 0 + slice_ = slice( + # Make sure that the lookback buffer is part of the new slice as well. + (slice_.start or 0) - self._len_lookback_buffer, + slice_.stop, + slice_.step, + ) + slice_obs_infos = slice( + slice_.start, + # Obs and infos need one more step at the end. + ((slice_.stop if slice_.stop != -1 else (len(self) - 1)) or len(self)) + 1, + slice_.step, + ) + return SingleAgentEpisode( + id_=self.id_, + observations=self.get_observations( + slice_obs_infos, + neg_indices_left_of_zero=neg_indices_left_of_zero, + ), + infos=self.get_infos( + slice_obs_infos, + neg_indices_left_of_zero=neg_indices_left_of_zero, + ), + actions=self.get_actions( + slice_, + neg_indices_left_of_zero=neg_indices_left_of_zero, + ), + rewards=self.get_rewards( + slice_, + neg_indices_left_of_zero=neg_indices_left_of_zero, + ), + extra_model_outputs={ + k: self.get_extra_model_outputs( + k, + slice_, + neg_indices_left_of_zero=neg_indices_left_of_zero, + ) + for k in self.extra_model_outputs + }, + terminated=(self.is_terminated if keep_done else False), + truncated=(self.is_truncated if keep_done else False), + # Provide correct timestep- and pre-buffer information. + t_started=t_started, + len_lookback_buffer=self._len_lookback_buffer, + ) - def get_state(self) -> Dict[str, Any]: - """Returns the pickable state of an episode. + def get_data_dict(self): + """Converts a SingleAgentEpisode into a data dict mapping str keys to data. - The data in the episode is stored into a dictionary. Note that episodes - can also be generated from states (see `self.from_state()`). + The keys used are: + SampleBatch.EPS_ID, T, OBS, INFOS, ACTIONS, REWARDS, TERMINATEDS, TRUNCATEDS, + and those in `self.extra_model_outputs`. Returns: - A dictionary containing all the data from the episode. + A data dict mapping str keys to data records. """ - return list( + t = list(range(self.t_started, self.t)) + terminateds = [False] * (len(self) - 1) + [self.is_terminated] + truncateds = [False] * (len(self) - 1) + [self.is_truncated] + eps_id = [self.id_] * len(self) + + if self.is_finalized: + t = np.array(t) + terminateds = np.array(terminateds) + truncateds = np.array(truncateds) + eps_id = np.array(eps_id) + + return dict( { - "id_": self.id_, - "observations": self.observations, - "actions": self.actions, - "rewards": self.rewards, - "infos": self.infos, - "states": self.states, - "t_started": self.t_started, - "t": self.t, - "is_terminated": self.is_terminated, - "is_truncated": self.is_truncated, - **self.extra_model_outputs, - }.items() + # Trivial 1D data (compiled above). + SampleBatch.TERMINATEDS: terminateds, + SampleBatch.TRUNCATEDS: truncateds, + SampleBatch.T: t, + SampleBatch.EPS_ID: eps_id, + # Retrieve obs, infos, actions, rewards using our get_... APIs, + # which return all relevant timesteps (excluding the lookback + # buffer!). + SampleBatch.OBS: self.get_observations(slice(None, -1)), + SampleBatch.INFOS: self.get_infos(), + SampleBatch.ACTIONS: self.get_actions(), + SampleBatch.REWARDS: self.get_rewards(), + }, + # All `extra_model_outs`: Same as obs: Use get_... API. + **{ + k: self.get_extra_model_outputs(k) + for k in self.extra_model_outputs.keys() + }, ) - @staticmethod - def from_state(state: Dict[str, Any]) -> "SingleAgentEpisode": - """Generates a `SingleAegntEpisode` from a pickable state. + def get_sample_batch(self) -> SampleBatch: + """Converts this `SingleAgentEpisode` into a `SampleBatch`. - The data in the state has to be complete. This is always the case when the state - was created by a `SingleAgentEpisode` itself calling `self.get_state()`. + Returns: + A SampleBatch containing all of this episode's data. + """ + return SampleBatch(self.get_data_dict()) - Args: - state: A dictionary containing all episode data. + def get_return(self) -> float: + """Calculates an episode's return, excluding the lookback buffer's rewards. + + The return is computed by a simple sum, neglecting the discount factor. Returns: - A `SingleAgentEpisode` instance holding all the data provided by `state`. + The sum of rewards collected during this episode, excluding possible data + inside the lookback buffer. """ - eps = SingleAgentEpisode(id_=state[0][1]) - eps.observations = state[1][1] - eps.actions = state[2][1] - eps.rewards = state[3][1] - eps.infos = state[4][1] - eps.states = state[5][1] - eps.t_started = state[6][1] - eps.t = state[7][1] - eps.is_terminated = state[8][1] - eps.is_truncated = state[9][1] - eps.extra_model_outputs = {k: v for k, v in state[10:]} - # Validate the episode to ensure complete data. - eps.validate() - return eps + return sum(self.get_rewards()) def __len__(self) -> int: """Returning the length of an episode. - The length of an episode is defined by the length of its data. This is the - number of timesteps an agent has stepped through an environment so far. - The length is undefined in case of a just started episode. + The length of an episode is defined by the length of its data, excluding + the lookback buffer data. The length is the number of timesteps an agent has + stepped through an environment thus far. + + The length is 0 in case of an episode whose env has NOT been reset yet, but + also 0 right after the `env.reset()` data has been added via + `self.add_env_reset()`. Only after the first call to `env.step()` (and + `self.add_env_step()`, the length will be 1. Returns: An integer, defining the length of an episode. - - Raises: - AssertionError: If episode has never been stepped so far. """ - assert len(self.observations) > 0, ( - "ERROR: Cannot determine length of episode that hasn't started yet! Call " - "`SingleAgentEpisode.add_initial_observation(initial_observation=...)` " - "first (after which `len(SingleAgentEpisode)` will be 0)." - ) - return len(self.observations) - 1 + return self.t - self.t_started def __repr__(self): return f"SAEps({self.id_} len={len(self)})" + + def __getitem__(self, item: slice) -> "SingleAgentEpisode": + """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:].""" + if isinstance(item, slice): + return self.slice(slice_=item) + else: + raise NotImplementedError( + f"SingleAgentEpisode does not support getting item '{item}'! " + "Only slice objects allowed with the syntax: `episode[a:b]`." + ) diff --git a/rllib/env/testing/__init__.py b/rllib/env/testing/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/rllib/env/testing/single_agent_gym_env_runner.py b/rllib/env/testing/single_agent_gym_env_runner.py deleted file mode 100644 index 2bf885b94c54..000000000000 --- a/rllib/env/testing/single_agent_gym_env_runner.py +++ /dev/null @@ -1,243 +0,0 @@ -from typing import List, Optional, Tuple - -import gymnasium as gym - -from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.env.env_runner import EnvRunner -from ray.rllib.env.single_agent_episode import SingleAgentEpisode -from ray.rllib.utils.annotations import override - - -class SingleAgentGymEnvRunner(EnvRunner): - """A simple single-agent EnvRunner subclass for testing purposes. - - Uses a gym.vector.Env environment and random actions. - """ - - def __init__(self, *, config: AlgorithmConfig, **kwargs): - """Initializes a SingleAgentGymEnvRunner instance. - - Args: - config: The config to use to setup this EnvRunner. - """ - super().__init__(config=config, **kwargs) - - # Create the gym.vector.Env object. - self.env = gym.vector.make( - self.config.env, - num_envs=self.config.num_envs_per_worker, - asynchronous=self.config.remote_worker_envs, - **dict(self.config.env_config, **{"render_mode": "rgb_array"}), - ) - self.num_envs = self.env.num_envs - - self._needs_initial_reset = True - self._episodes = [None for _ in range(self.num_envs)] - - @override(EnvRunner) - def sample( - self, - *, - num_timesteps: Optional[int] = None, - num_episodes: Optional[int] = None, - force_reset: bool = False, - **kwargs, - ) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]: - """Returns a tuple (list of completed episodes, list of ongoing episodes). - - Args: - num_timesteps: If provided, will step exactly this number of timesteps - through the environment. Note that only one or none of `num_timesteps` - and `num_episodes` may be provided, but never both. If both - `num_timesteps` and `num_episodes` are None, will determine how to - sample via `self.config`. - num_episodes: If provided, will step through the env(s) until exactly this - many episodes have been completed. Note that only one or none of - `num_timesteps` and `num_episodes` may be provided, but never both. - If both `num_timesteps` and `num_episodes` are None, will determine how - to sample via `self.config`. - force_reset: If True, will force-reset the env at the very beginning and - thus begin sampling from freshly started episodes. - **kwargs: Forward compatibility kwargs. - - Returns: - A tuple consisting of: A list of SingleAgentEpisode instances that are - already done (either terminated or truncated, hence their `is_done` property - is True), a list of SingleAgentEpisode instances that are still ongoing - (their `is_done` property is False). - """ - assert not (num_timesteps is not None and num_episodes is not None) - - # If no counters are provided, use our configured default settings. - if num_timesteps is None and num_episodes is None: - # Truncate episodes -> num_timesteps = rollout fragment * num_envs. - if self.config.batch_mode == "truncate_episodes": - num_timesteps = self.config.rollout_fragment_length * self.num_envs - # Complete episodes -> each env runs one episode. - else: - num_episodes = self.num_envs - - if num_timesteps is not None: - return self._sample_timesteps( - num_timesteps=num_timesteps, - force_reset=force_reset, - ) - else: - return self._sample_episodes(num_episodes=num_episodes) - - def _sample_timesteps( - self, - num_timesteps: int, - force_reset: bool = False, - ) -> Tuple[List[SingleAgentEpisode], List[SingleAgentEpisode]]: - """Runs n timesteps on the environment(s) and returns experiences. - - Timesteps are counted in total (across all vectorized sub-environments). For - example, if self.num_envs=2 and num_timesteps=10, each sub-environment - will be sampled for 5 steps. - """ - done_episodes_to_return = [] - - # Have to reset the env (on all vector sub-envs). - if force_reset or self._needs_initial_reset: - obs, _ = self.env.reset() - # Start new episodes. - # Set initial observations of the new episodes. - self._episodes = [ - SingleAgentEpisode(observations=[o]) for o in self._split_by_env(obs) - ] - self._needs_initial_reset = False - - # Loop for as long as we have not reached `num_timesteps` timesteps. - ts = 0 - while ts < num_timesteps: - # Act randomly. - actions = self.env.action_space.sample() - obs, rewards, terminateds, truncateds, infos = self.env.step(actions) - # Count timesteps across all environments. - ts += self.num_envs - - # Process env-returned data. - for i, (o, a, r, term, trunc) in enumerate( - zip( - self._split_by_env(obs), - self._split_by_env(actions), - self._split_by_env(rewards), - self._split_by_env(terminateds), - self._split_by_env(truncateds), - ) - ): - # Episode is done (terminated or truncated). - # Note that gym.vector.Env reset done sub-environments automatically - # (and start a new episode). The step-returned observation is then - # the new episode's reset observation and the final observation of - # the old episode can be found in the info dict. - if term or trunc: - # Finish the episode object with the actual terminal observation - # stored in the info dict (`o` is already the reset obs of the new - # episode). - self._episodes[i].add_timestep( - infos["final_observation"][i], - a, - r, - is_terminated=term, - is_truncated=trunc, - ) - # Add this finished episode to the list of completed ones. - done_episodes_to_return.append(self._episodes[i]) - # Start a new episode and set its initial observation to `o`. - self._episodes[i] = SingleAgentEpisode(observations=[o]) - # Episode is ongoing -> Add a timestep. - else: - self._episodes[i].add_timestep(o, a, r) - - # Return done episodes and all ongoing episode chunks, then start new episode - # chunks for those episodes that are still ongoing. - ongoing_episodes = self._episodes - # Create new chunks (using the same IDs and latest observations). - self._episodes = [ - SingleAgentEpisode(id_=eps.id_, observations=[eps.observations[-1]]) - for eps in self._episodes - ] - # Return tuple: done episodes, ongoing ones. - return done_episodes_to_return, ongoing_episodes - - def _sample_episodes( - self, - num_episodes: int, - ): - """Runs n episodes (reset first) on the environment(s) and returns experiences. - - Episodes are counted in total (across all vectorized sub-environments). For - example, if self.num_envs=2 and num_episodes=10, each sub-environment - will run 5 episodes. - """ - done_episodes_to_return = [] - - obs, _ = self.env.reset() - episodes = [ - SingleAgentEpisode(observations=[o]) for o in self._split_by_env(obs) - ] - - eps = 0 - while eps < num_episodes: - actions = self.env.action_space.sample() - obs, rewards, terminateds, truncateds, infos = self.env.step(actions) - - for i, (o, a, r, term, trunc) in enumerate( - zip( - self._split_by_env(obs), - self._split_by_env(actions), - self._split_by_env(rewards), - self._split_by_env(terminateds), - self._split_by_env(truncateds), - ) - ): - # Episode is done (terminated or truncated). - # Note that gym.vector.Env reset done sub-environments automatically - # (and start a new episode). The step-returned observation is then - # the new episode's reset observation and the final observation of - # the old episode can be found in the info dict. - if term or trunc: - eps += 1 - # Finish the episode object with the actual terminal observation - # stored in the info dict. - episodes[i].add_timestep( - infos["final_observation"][i], - a, - r, - is_terminated=term, - is_truncated=trunc, - ) - # Add this finished episode to the list of completed ones. - done_episodes_to_return.append(episodes[i]) - - # Also early-out if we reach the number of episodes within this - # for-loop. - if eps == num_episodes: - break - - # Start a new episode and set its initial observation to `o`. - episodes[i] = SingleAgentEpisode(observations=[o]) - else: - episodes[i].add_timestep(o, a, r) - - # If user calls sample(num_timesteps=..) after this, we must reset again - # at the beginning. - self._needs_initial_reset = True - - # Return 2 lists: finished and ongoing episodes. - return done_episodes_to_return, [] - - @override(EnvRunner) - def assert_healthy(self): - # Make sure, we have built our gym.vector.Env properly. - assert self.env - - @override(EnvRunner) - def stop(self): - # Close our env object via gymnasium's API. - self.env.close() - - def _split_by_env(self, inputs): - return [inputs[i] for i in range(self.num_envs)] diff --git a/rllib/env/tests/test_lookback_buffer.py b/rllib/env/tests/test_lookback_buffer.py new file mode 100644 index 000000000000..702585e29104 --- /dev/null +++ b/rllib/env/tests/test_lookback_buffer.py @@ -0,0 +1,491 @@ +import unittest + +import gymnasium as gym +import numpy as np + +from ray.rllib.env.utils import BufferWithInfiniteLookback +from ray.rllib.utils.spaces.space_utils import batch, get_dummy_batch_for_space +from ray.rllib.utils.test_utils import check + + +class TestBufferWithInfiniteLookback(unittest.TestCase): + space = gym.spaces.Dict( + { + "a": gym.spaces.Discrete(4), + "b": gym.spaces.Box(-1.0, 1.0, (2, 3)), + "c": gym.spaces.Tuple( + [gym.spaces.MultiDiscrete([2, 3]), gym.spaces.Box(-1.0, 1.0, (1,))] + ), + } + ) + + def test_append_and_pop(self): + buffer = BufferWithInfiniteLookback(data=[0, 1, 2, 3]) + self.assertTrue(len(buffer), 4) + buffer.append(4) + self.assertTrue(len(buffer), 5) + buffer.append(5) + self.assertTrue(len(buffer), 6) + buffer.pop() + self.assertTrue(len(buffer), 5) + buffer.pop() + self.assertTrue(len(buffer), 4) + buffer.append(10) + self.assertTrue(len(buffer), 5) + check(buffer.data, [0, 1, 2, 3, 10]) + buffer.finalize() + self.assertRaises(RuntimeError, lambda: buffer.append("something")) + self.assertRaises(RuntimeError, lambda: buffer.extend(["something"])) + self.assertRaises(RuntimeError, lambda: buffer.pop()) + + def test_complex_structs(self): + buffer = BufferWithInfiniteLookback( + data=[self.space.sample() for _ in range(4)] + ) + self.assertTrue(len(buffer), 4) + buffer.append(self.space.sample()) + self.assertTrue(len(buffer), 5) + buffer.append(self.space.sample()) + self.assertTrue(len(buffer), 6) + + buffer.finalize() + self.assertRaises(RuntimeError, lambda: buffer.append("something")) + + self.assertTrue(isinstance(buffer.data, dict)) + self.assertTrue(isinstance(buffer.data["a"], np.ndarray)) + self.assertTrue(isinstance(buffer.data["b"], np.ndarray)) + self.assertTrue(isinstance(buffer.data["c"], tuple)) + self.assertTrue(isinstance(buffer.data["c"][0], np.ndarray)) + self.assertTrue(isinstance(buffer.data["c"][1], np.ndarray)) + + def test_lookback(self): + buffer = BufferWithInfiniteLookback(data=[0, 1, 2, 3], lookback=2) + self.assertTrue(len(buffer), 2) + data_no_lookback = buffer.get() + check(data_no_lookback, [2, 3]) + buffer.append(4) + self.assertTrue(len(buffer), 3) + buffer.append(5) + self.assertTrue(len(buffer), 4) + data_no_lookback = buffer.get() + check(data_no_lookback, [2, 3, 4, 5]) + buffer.pop() + self.assertTrue(len(buffer), 3) + data_no_lookback = buffer.get() + check(data_no_lookback, [2, 3, 4]) + + def test_get_with_lookback(self): + """Tests `get` and `getitem` functionalities with a lookback range > 0.""" + buffer = BufferWithInfiniteLookback(data=[0, 1, 2, 3, 4], lookback=2) + + # Test on ongoing and finalized buffer. + for finalized in [False, True]: + if finalized: + buffer.finalize() + + self.assertTrue(len(buffer), 3) + # No args: Expect all contents excluding lookback buffer. + check(buffer.get(), [2, 3, 4]) + check(buffer[:], [2, 3, 4]) + # Individual negative indices (include lookback buffer). + check(buffer.get(-1), 4) + check(buffer[-1], 4) + check(buffer.get(-2), 3) + check(buffer[-2], 3) + check(buffer.get(-4), 1) + check(buffer[-4], 1) + check(buffer.get([-4]), [1]) + check(buffer[-4:-3], [1]) + self.assertRaises(IndexError, lambda: buffer.get(-6)) + self.assertRaises(IndexError, lambda: buffer[-6]) + self.assertRaises(IndexError, lambda: buffer.get(-1000)) + self.assertRaises(IndexError, lambda: buffer[-1000]) + # Individual positive indices (do NOT include lookback buffer). + check(buffer.get(0), 2) + check(buffer[0], 2) + check(buffer.get(1), 3) + check(buffer[1], 3) + check(buffer.get(2), 4) + check(buffer[2], 4) + check(buffer.get([2]), [4]) + check(buffer[2:3], [4]) + self.assertRaises(IndexError, lambda: buffer.get(3)) + self.assertRaises(IndexError, lambda: buffer[3]) + self.assertRaises(IndexError, lambda: buffer.get(1000)) + self.assertRaises(IndexError, lambda: buffer[1000]) + # List of negative indices (include lookback buffer). + check(buffer.get([-4, -5]), [1, 0]) + check(buffer.get([-1]), [4]) + check(buffer[-1:], [4]) + check(buffer.get([-5]), [0]) + check(buffer[-5:-4], [0]) + self.assertRaises(IndexError, lambda: buffer.get([-6])) + self.assertRaises(IndexError, lambda: buffer.get([-1, -6])) + self.assertRaises(IndexError, lambda: buffer.get([-1000])) + # List of positive indices (do NOT include lookback buffer). + check(buffer.get([1, 0, 2]), [3, 2, 4]) + check(buffer.get([0, 2, 1]), [2, 4, 3]) + self.assertRaises(IndexError, lambda: buffer.get([6])) + self.assertRaises(IndexError, lambda: buffer.get([1, 6])) + self.assertRaises(IndexError, lambda: buffer.get([1000])) + # List of positive and negative indices (do NOT include lookback buffer). + check(buffer.get([1, 0, -2]), [3, 2, 3]) + check(buffer.get([-3, 1, -1]), [2, 3, 4]) + # Slices. + # Type: [None:...] + check(buffer.get(slice(None, None)), [2, 3, 4]) + check(buffer.get(slice(None, 2)), [2, 3]) + check(buffer[:2], [2, 3]) + check(buffer.get(slice(3)), [2, 3, 4]) + check(buffer[:3], [2, 3, 4]) + check(buffer.get(slice(None, -1)), [2, 3]) + check(buffer[:-1], [2, 3]) + check(buffer.get(slice(None, -2)), [2]) + check(buffer[:-2], [2]) + # Type: [...:None] + check(buffer.get(slice(2, None)), [4]) + check(buffer[2:], [4]) + check(buffer.get(slice(2, 5)), [4]) + check(buffer[2:5], [4]) + check(buffer.get(slice(1, None)), [3, 4]) + check(buffer[1:], [3, 4]) + check(buffer.get(slice(1, 5)), [3, 4]) + check(buffer[1:5], [3, 4]) + check(buffer.get(slice(-1, None)), [4]) + check(buffer[-1:], [4]) + check(buffer.get(slice(-1, 5)), [4]) + check(buffer[-1:5], [4]) + check(buffer.get(slice(-4, None)), [1, 2, 3, 4]) + check(buffer[-4:], [1, 2, 3, 4]) + check(buffer.get(slice(-4, 5)), [1, 2, 3, 4]) + check(buffer[-4:5], [1, 2, 3, 4]) + # Type: [-n:-m] + check(buffer.get(slice(-2, -1)), [3]) + check(buffer[-2:-1], [3]) + check(buffer.get(slice(-3, -1)), [2, 3]) + check(buffer[-3:-1], [2, 3]) + check(buffer.get(slice(-4, -2)), [1, 2]) + check(buffer[-4:-2], [1, 2]) + check(buffer.get(slice(-4, -1)), [1, 2, 3]) + check(buffer[-4:-1], [1, 2, 3]) + check(buffer.get(slice(-5, -1)), [0, 1, 2, 3]) + check(buffer[-5:-1], [0, 1, 2, 3]) + check(buffer.get(slice(-6, -2)), [0, 1, 2]) + check(buffer[-6:-2], [0, 1, 2]) + # Type: [+n:+m] + check(buffer.get(slice(0, 1)), [2]) + check(buffer[0:1], [2]) + check(buffer.get(slice(0, 2)), [2, 3]) + check(buffer[0:2], [2, 3]) + check(buffer.get(slice(0, 3)), [2, 3, 4]) + check(buffer[0:3], [2, 3, 4]) + check(buffer.get(slice(1, 2)), [3]) + check(buffer[1:2], [3]) + check(buffer.get(slice(1, 3)), [3, 4]) + check(buffer[1:3], [3, 4]) + check(buffer.get(slice(2, 3)), [4]) + check(buffer[2:3], [4]) + # Type: [+n:-m] + check(buffer.get(slice(0, -1)), [2, 3]) + check(buffer[0:-1], [2, 3]) + check(buffer.get(slice(0, -2)), [2]) + check(buffer[0:-2], [2]) + check(buffer.get(slice(1, -1)), [3]) + check(buffer[1:-1], [3]) + + # Check the type on the finalized buffer (numpy arrays). + data = buffer.get([1, 0, 2]) + self.assertTrue(isinstance(data, np.ndarray)) + check(data, [3, 2, 4]) + + def test_get_with_lookback_and_fill(self): + """Tests the `fill` argument of `get` with a lookback range >0.""" + buffer = BufferWithInfiniteLookback( + data=[0, 1, 2, 3, 4, 5], + lookback=3, + # Specify a space, so we can fill and one-hot discrete data properly. + space=gym.spaces.Discrete(6), + ) + + # Test on ongoing and finalized buffer. + for finalized in [False, True]: + if finalized: + buffer.finalize() + + self.assertTrue(len(buffer), 3) + + # Individual indices with fill. + check(buffer.get(-10, fill=10), 10) + check(buffer.get(-3, fill=10), 3) + check(buffer.get(-2, fill=10), 4) + check(buffer.get(-1, fill=10), 5) + check(buffer.get(0, fill=10), 3) + check(buffer.get(2, fill=10), 5) + check(buffer.get(100, fill=10), 10) + + # Left fill. + check(buffer.get(slice(-8, None), fill=10), [10, 10, 0, 1, 2, 3, 4, 5]) + check(buffer.get(slice(-9, None), fill=10), [10, 10, 10, 0, 1, 2, 3, 4, 5]) + check( + buffer.get(slice(-10, None), fill=11), + [11, 11, 11, 11, 0, 1, 2, 3, 4, 5], + ) + check(buffer.get(slice(-10, -4), fill=11), [11, 11, 11, 11, 0, 1]) + # Both start stop on left side. + check(buffer.get(slice(-10, -9), fill=0), [0]) + check(buffer.get(slice(-20, -15), fill=0), [0, 0, 0, 0, 0]) + check(buffer.get(slice(-1001, -1000), fill=6), [6]) + # Both start stop on right side. + check(buffer.get(slice(10, 15), fill=0), [0, 0, 0, 0, 0]) + check(buffer.get(slice(15, 17), fill=0), [0, 0]) + check(buffer.get(slice(1000, 1001), fill=6), [6]) + # Right fill. + check(buffer.get(slice(2, 8), fill=12), [5, 12, 12, 12, 12, 12]) + check(buffer.get(slice(1, 7), fill=13), [4, 5, 13, 13, 13, 13]) + check(buffer.get(slice(1, 5), fill=-14), [4, 5, -14, -14]) + # No fill necessary (even though requested). + check(buffer.get(slice(-5, None), fill=999), [1, 2, 3, 4, 5]) + check(buffer.get(slice(-6, -1), fill=999), [0, 1, 2, 3, 4]) + check(buffer.get(slice(0, 2), fill=999), [3, 4]) + check(buffer.get(slice(1, None), fill=999), [4, 5]) + check(buffer.get(slice(None, 3), fill=999), [3, 4, 5]) + + # Check the type on the finalized buffer (numpy arrays). + data = buffer.get(slice(15, 17), fill=0) + self.assertTrue(isinstance(data, np.ndarray)) + check(data, [0, 0]) + + def test_get_with_fill_and_neg_indices_into_lookback_buffer(self): + """Tests the `fill` argument of `get` with a lookback range >0.""" + buffer = BufferWithInfiniteLookback( + data=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + lookback=4, + # Specify a space, so we can fill and one-hot discrete data properly. + space=gym.spaces.Discrete(11), + ) + + # Test on ongoing and finalized buffer. + for finalized in [False, True]: + if finalized: + buffer.finalize() + + self.assertTrue(len(buffer), 7) + + # Lokback buffer is [0, 1, 2, 3] + # Individual indices with negative indices into lookback buffer. + check(buffer.get(-1, neg_indices_left_of_zero=True), 3) + check(buffer.get(-2, neg_indices_left_of_zero=True), 2) + check(buffer.get(-3, neg_indices_left_of_zero=True), 1) + check(buffer.get(-4, neg_indices_left_of_zero=True), 0) + check(buffer.get([-1, -3], neg_indices_left_of_zero=True), [3, 1]) + # Slices with negative indices into lookback buffer. + check(buffer.get(slice(-2, -1), neg_indices_left_of_zero=True), [2]) + check(buffer.get(slice(-2, 0), neg_indices_left_of_zero=True), [2, 3]) + check( + buffer.get(slice(-2, 4), neg_indices_left_of_zero=True), + [2, 3, 4, 5, 6, 7], + ) + check( + buffer.get(slice(-2, None), neg_indices_left_of_zero=True), + [2, 3, 4, 5, 6, 7, 8, 9, 10], + ) + # With left fill. + check(buffer.get(-8, fill=10, neg_indices_left_of_zero=True), 10) + check(buffer.get(-800, fill=10, neg_indices_left_of_zero=True), 10) + check(buffer.get([-8, -1], fill=9, neg_indices_left_of_zero=True), [9, 3]) + check( + buffer.get(slice(-8, 0), fill=10, neg_indices_left_of_zero=True), + [10, 10, 10, 10, 0, 1, 2, 3], + ) + check( + buffer.get(slice(-7, 1), fill=10, neg_indices_left_of_zero=True), + [10, 10, 10, 0, 1, 2, 3, 4], + ) + check( + buffer.get(slice(-6, None), fill=11, neg_indices_left_of_zero=True), + [11, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ) + check( + buffer.get(slice(-10, -4), fill=11, neg_indices_left_of_zero=True), + [11, 11, 11, 11, 11, 11], + ) + # Both start stop on left side. + check( + buffer.get(slice(-10, -9), fill=0, neg_indices_left_of_zero=True), [0] + ) + check( + buffer.get(slice(-20, -15), fill=0, neg_indices_left_of_zero=True), + [0, 0, 0, 0, 0], + ) + check( + buffer.get(slice(-1001, -1000), fill=6, neg_indices_left_of_zero=True), + [6], + ) + # Both start stop on right side. + check( + buffer.get(slice(10, 15), fill=0, neg_indices_left_of_zero=True), + [0, 0, 0, 0, 0], + ) + check( + buffer.get(slice(15, 17), fill=0, neg_indices_left_of_zero=True), [0, 0] + ) + check( + buffer.get(slice(1000, 1001), fill=6, neg_indices_left_of_zero=True), + [6], + ) + # Right fill. + check(buffer.get(8, fill=10, neg_indices_left_of_zero=True), 10) + check(buffer.get(800, fill=10, neg_indices_left_of_zero=True), 10) + check(buffer.get([8, 1], fill=9, neg_indices_left_of_zero=True), [9, 5]) + check( + buffer.get(slice(-2, 8), fill=12, neg_indices_left_of_zero=True), + [2, 3, 4, 5, 6, 7, 8, 9, 10, 12], + ) + check( + buffer.get(slice(-1, 9), fill=13, neg_indices_left_of_zero=True), + [3, 4, 5, 6, 7, 8, 9, 10, 13, 13], + ) + check( + buffer.get(slice(-1, 5), fill=-14, neg_indices_left_of_zero=True), + [3, 4, 5, 6, 7, 8], + ) + # No fill necessary (even though requested). + check( + buffer.get(slice(-1, None), fill=999, neg_indices_left_of_zero=True), + [3, 4, 5, 6, 7, 8, 9, 10], + ) + check( + buffer.get(slice(-4, -1), fill=999, neg_indices_left_of_zero=True), + [0, 1, 2], + ) + check( + buffer.get(slice(-1, 2), fill=999, neg_indices_left_of_zero=True), + [3, 4, 5], + ) + + # Check the type on the finalized buffer (numpy arrays). + data = buffer.get(slice(-17, -15), fill=0, neg_indices_left_of_zero=True) + self.assertTrue(isinstance(data, np.ndarray)) + check(data, [0, 0]) + data = buffer.get([-3, -1], fill=0, neg_indices_left_of_zero=True) + self.assertTrue(isinstance(data, np.ndarray)) + check(data, [1, 3]) + + def test_get_with_fill_0_and_zero_hot(self): + """Tests, whether zero-hot is properly done when fill=0.""" + buffer = BufferWithInfiniteLookback( + data=[0, 1, 0, 1], + # Specify a space, so we can fill and one-hot discrete data properly. + space=gym.spaces.Discrete(2), + ) + + # Test on ongoing and finalized buffer. + for finalized in [False, True]: + if finalized: + buffer.finalize() + + self.assertTrue(len(buffer), 4) + + # Right side fill 0. Should be zero-hot. + check(buffer.get(4, fill=0, one_hot_discrete=True), [0, 0]) + check( + buffer.get( + -1, + neg_indices_left_of_zero=True, + fill=0, + one_hot_discrete=True, + ), + [0, 0], + ) + + def test_get_with_complex_space(self): + """Tests, whether zero-hot is properly done when fill=0.""" + buffer = BufferWithInfiniteLookback( + data=[ + get_dummy_batch_for_space( + space=self.space, + batch_size=0, + fill_value=float(i), + ) + for i in range(4) + ], + lookback=2, + # Specify a space, so we can fill and one-hot discrete data properly. + space=self.space, + ) + + buffer_0 = { + "a": 0, + "b": np.array([[0, 0, 0], [0, 0, 0]]), + "c": (np.array([0, 0]), np.array([0])), + } + buffer_0_one_hot = { + "a": np.array([0.0, 0.0, 0.0, 0.0]), + "b": np.array([[0, 0, 0], [0, 0, 0]]), + "c": (np.array([0, 0, 0, 0, 0]), np.array([0])), + } + buffer_1 = { + "a": 1, + "b": np.array([[1, 1, 1], [1, 1, 1]]), + "c": (np.array([1, 1]), np.array([1])), + } + buffer_2 = { + "a": 2, + "b": np.array([[2, 2, 2], [2, 2, 2]]), + "c": (np.array([2, 2]), np.array([2])), + } + buffer_3 = { + "a": 3, + "b": np.array([[3, 3, 3], [3, 3, 3]]), + "c": (np.array([3, 3]), np.array([3])), + } + + # Test on ongoing and finalized buffer. + for finalized in [False, True]: + if finalized: + buffer.finalize() + + def batch_(s): + return batch(s) + + else: + + def batch_(s): + return s + + self.assertTrue(len(buffer), 2) + + check(buffer.get(-1), buffer_3) + check(buffer.get(-2), buffer_2) + check(buffer.get(-3), buffer_1) + check(buffer.get(-4), buffer_0) + check(buffer.get(-5, fill=0.0), buffer_0) + check(buffer.get([-5, 5], fill=0.0), batch_([buffer_0, buffer_0])) + check(buffer.get([-5, 1], fill=0.0), batch_([buffer_0, buffer_3])) + check(buffer.get([1, -10], fill=0.0), batch_([buffer_3, buffer_0])) + check( + buffer.get([-10], fill=0.0, one_hot_discrete=True), + batch_([buffer_0_one_hot]), + ) + check(buffer.get(slice(0, 1), fill=0.0), batch_([buffer_2])) + check(buffer.get(slice(1, 3), fill=0.0), batch_([buffer_3, buffer_0])) + check(buffer.get(slice(-10, -12), fill=0.0), batch_([buffer_0, buffer_0])) + check( + buffer.get(slice(-10, -12), fill=0.0, neg_indices_left_of_zero=True), + batch_([buffer_0, buffer_0]), + ) + check( + buffer.get(slice(100, 98), fill=0.0, neg_indices_left_of_zero=True), + batch_([buffer_0, buffer_0]), + ) + check( + buffer.get(slice(100, 98), fill=0.0), + batch_([buffer_0, buffer_0]), + ) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index 8405f454faee..a53f3ffa16e6 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -150,8 +150,8 @@ def test_init(self): rewards = [] actions = [] infos = [] - is_terminateds = [] - is_truncateds = [] + terminateds = {} + truncateds = {} extra_model_outputs = [] # Initialize observation and info. obs, info = env.reset(seed=0) @@ -159,33 +159,31 @@ def test_init(self): infos.append(info) # Run 100 samples. for i in range(100): - agents_stepped = list(obs.keys()) - action = { - agent_id: i + 1 - for agent_id in agents_stepped - if agent_id in env._agents_alive - } + agents_to_step_next = [ + aid for aid in obs.keys() if aid in env._agents_alive + ] + action = {agent_id: i + 1 for agent_id in agents_to_step_next} # action = env.action_space_sample(agents_stepped) - obs, reward, is_terminated, is_truncated, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) observations.append(obs) actions.append(action) rewards.append(reward) infos.append(info) - is_terminateds.append(is_terminated) - is_truncateds.append(is_truncated) + terminateds.update(terminated) + truncateds.update(truncated) extra_model_outputs.append( - {agent_id: {"extra_1": 10.5} for agent_id in agents_stepped} + {agent_id: {"extra_1": 10.5} for agent_id in agents_to_step_next} ) # Now create the episode from the recorded data. episode = MultiAgentEpisode( - agent_ids=env.get_agent_ids(), + agent_ids=list(env.get_agent_ids()), observations=observations, actions=actions, rewards=rewards, infos=infos, - is_terminated=is_terminateds, - is_truncated=is_truncateds, + terminateds=terminateds, + truncateds=truncateds, extra_model_outputs=extra_model_outputs, ) @@ -214,8 +212,8 @@ def test_init(self): rewards=rewards[-10:], infos=infos[-11:], t_started=100, - is_terminated=is_terminateds[-10:], - is_truncated=is_truncateds[-10:], + terminateds=terminateds, + truncateds=truncateds, extra_model_outputs=extra_model_outputs[-10:], ) @@ -239,8 +237,8 @@ def test_init(self): observations, actions, rewards, - is_terminated, - is_truncated, + terminateds, + truncateds, infos, ) = self._mock_multi_agent_records() @@ -251,8 +249,8 @@ def test_init(self): actions=actions, rewards=rewards, infos=infos, - is_terminated=is_terminated, - is_truncated=is_truncated, + terminateds=terminateds, + truncateds=truncateds, ) # Assert that the length of `SingleAgentEpisode`s are all correct. @@ -276,9 +274,9 @@ def test_add_initial_observation(self): # Generate initial observations and infos and add them to the episode. obs, infos = env.reset(seed=0) - episode.add_initial_observation( - initial_observation=obs, - initial_info=infos, + episode.add_env_reset( + observations=obs, + infos=infos, ) # Assert that timestep is at zero. @@ -295,15 +293,15 @@ def test_add_initial_observation(self): # TODO (simon): Test the buffers and reward storage. - def test_add_timestep(self): + def test_add_env_step(self): # Create an environment and add the initial observations, infos, and states. env = MultiAgentTestEnv() episode = MultiAgentEpisode(agent_ids=env.get_agent_ids()) obs, infos = env.reset(seed=0) - episode.add_initial_observation( - initial_observation=obs, - initial_info=infos, + episode.add_env_reset( + observations=obs, + infos=infos, ) # Sample 100 timesteps and add them to the episode. @@ -311,16 +309,16 @@ def test_add_timestep(self): action = { agent_id: i + 1 for agent_id in obs if agent_id in env._agents_alive } - obs, reward, is_terminated, is_truncated, info = env.step(action) - - episode.add_timestep( - observation=obs, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={agent_id: {"extra": 10.5} for agent_id in action}, + obs, reward, terminated, truncated, info = env.step(action) + + episode.add_env_step( + observations=obs, + actions=action, + rewards=reward, + infos=info, + terminateds=terminated, + truncateds=truncated, + extra_model_outputs={agent_id: {"extra": 10.5} for agent_id in action}, ) # Assert that the timestep is at 100. @@ -349,19 +347,19 @@ def test_add_timestep(self): action = { agent_id: i + 1 for agent_id in obs if agent_id in env._agents_alive } - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode.add_timestep( - observation=obs, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={agent_id: {"extra": 10.5} for agent_id in action}, + obs, reward, terminated, truncated, info = env.step(action) + episode.add_env_step( + observations=obs, + actions=action, + rewards=reward, + infos=info, + terminateds=terminated, + truncateds=truncated, + extra_model_outputs={agent_id: {"extra": 10.5} for agent_id in action}, ) # Assert that the environment is done. - self.assertTrue(is_truncated["__all__"]) + self.assertTrue(truncated["__all__"]) # Assert that each agent is done. for agent_id in episode._agent_ids: self.assertTrue(episode.agent_episodes[agent_id].is_done) @@ -378,8 +376,8 @@ def test_add_timestep(self): observations, actions, rewards, - is_terminated, - is_truncated, + terminated, + truncated, infos, ) = self._mock_multi_agent_records() @@ -389,25 +387,25 @@ def test_add_timestep(self): actions=actions, rewards=rewards, infos=infos, - is_terminated=is_terminated, - is_truncated=is_truncated, + terminateds=terminated, + truncateds=truncated, ) # Now test that intermediate rewards will get recorded and actions buffered. action = {"agent_2": 3, "agent_4": 3} observation = {"agent_1": 3, "agent_2": 3} reward = {"agent_1": 1.0, "agent_2": 1.0, "agent_3": 1.0, "agent_5": 1.0} infos = {"agent_1": {}, "agent_2": {}} - is_terminated = {k: False for k in observation.keys()} - is_terminated.update({"__all__": False}) - is_truncated = {k: False for k in observation.keys()} - is_truncated.update({"__all__": False}) - episode.add_timestep( - observation=observation, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, + terminated = {k: False for k in observation.keys()} + terminated.update({"__all__": False}) + truncated = {k: False for k in observation.keys()} + truncated.update({"__all__": False}) + episode.add_env_step( + observations=observation, + actions=action, + rewards=reward, + infos=infos, + terminateds=terminated, + truncateds=truncated, ) # Assert that the action buffer for agent 4 is full. # Note, agent 4 acts, but receives no observation. @@ -420,7 +418,7 @@ def test_add_timestep(self): episode.agent_buffers["agent_3"]["rewards"].put_nowait(1.0) episode.agent_buffers["agent_5"]["rewards"].put_nowait(1.0) - def test_create_successor(self): + def test_cut(self): # Create an environment. episode_1, _ = self._mock_multi_agent_records_from_env(size=100) @@ -428,7 +426,7 @@ def test_create_successor(self): self.assertEqual(episode_1.t, 100) # Create a successor. - episode_2 = episode_1.create_successor() + episode_2 = episode_1.cut() # Assert that it has the same id. self.assertEqual(episode_1.id_, episode_2.id_) # Assert that all `SingleAgentEpisode`s have identical ids. @@ -575,8 +573,8 @@ def test_create_successor(self): observations, actions, rewards, - is_terminated, - is_truncated, + terminateds, + truncateds, infos, ) = self._mock_multi_agent_records() @@ -587,8 +585,8 @@ def test_create_successor(self): actions=actions, rewards=rewards, infos=infos, - is_terminated=is_terminated, - is_truncated=is_truncated, + terminateds=terminateds, + truncateds=truncateds, ) # Assert that agents 1 and 3's buffers are indeed full. @@ -611,17 +609,17 @@ def test_create_successor(self): # add this to the buffer and to the global reward history. reward = {"agent_1": 2.0, "agent_2": 2.0, "agent_3": 2.0, "agent_5": 2.0} info = {"agent_1": {}, "agent_2": {}} - is_terminated = {k: False for k in observation.keys()} - is_terminated.update({"__all__": False}) - is_truncated = {k: False for k in observation.keys()} - is_truncated.update({"__all__": False}) - episode_1.add_timestep( - observation=observation, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, + terminateds = {k: False for k in observation.keys()} + terminateds.update({"__all__": False}) + truncateds = {k: False for k in observation.keys()} + truncateds.update({"__all__": False}) + episode_1.add_env_step( + observations=observation, + actions=action, + rewards=reward, + infos=info, + terminateds=terminateds, + truncateds=truncateds, ) # Check that the partial reward history is correct. @@ -647,7 +645,7 @@ def test_create_successor(self): self.assertEqual(episode_1.partial_rewards[agent_id][-1], 2.0) # Now create the successor. - episode_2 = episode_1.create_successor() + episode_2 = episode_1.cut() for agent_id, agent_eps in episode_2.agent_episodes.items(): if len(agent_eps.observations) > 0: @@ -704,8 +702,8 @@ def test_getters(self): actions=actions, rewards=rewards, infos=infos, - is_terminated=is_terminateds, - is_truncated=is_truncateds, + terminateds=is_terminateds, + truncateds=is_truncateds, extra_model_outputs=extra_model_outputs, ) @@ -791,9 +789,9 @@ def test_getters(self): # Test with initial observations only. episode_init_only = MultiAgentEpisode(agent_ids=agent_ids) - episode_init_only.add_initial_observation( - initial_observation=observations[0], - initial_info=infos[0], + episode_init_only.add_env_reset( + observation=observations[0], + infos=infos[0], ) # Get the last observation for agents and assert that its correct. last_observation = episode_init_only.get_observations() @@ -985,22 +983,22 @@ def test_getters(self): # Generate initial observation and info. obs, info = env.reset(seed=42) - episode_1.add_initial_observation( - initial_observation=obs, - initial_info=info, + episode_1.add_env_reset( + observations=obs, + infos=info, ) # Now, generate 100 samples. for i in range(100): action = {agent_id: i for agent_id in obs} - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode_1.add_timestep( - observation=obs, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={agent_id: {"extra": 10} for agent_id in action}, + obs, reward, terminated, truncated, info = env.step(action) + episode_1.add_env_step( + observations=obs, + actions=action, + rewards=reward, + infos=info, + terminateds=terminated, + truncateds=truncated, + extra_model_outputs={agent_id: {"extra": 10} for agent_id in action}, ) # First, receive the last rewards without considering buffered values. @@ -1644,7 +1642,7 @@ def test_concat_episode(self): size=100, truncate=False ) # Now, create a successor episode. - episode_2 = episode_1.create_successor() + episode_2 = episode_1.cut() # Generate 100 more samples from the environment and store it in the episode. episode_2, env = self._mock_multi_agent_records_from_env( size=100, episode=episode_2, env=env, init=False @@ -1900,7 +1898,7 @@ def test_len(self): # Assert that the length is indeed 100. self.assertEqual(len(episode), 100) # Now, build a successor. - successor = episode.create_successor() + successor = episode.cut() # Sample another 100 timesteps. successor, env = self._mock_multi_agent_records_from_env( episode=successor, env=env, init=False @@ -1935,7 +1933,7 @@ def test_to_sample_batch(self): # Now test that when creating a successor its sample batch will # contain the correct values. - successor = episode.create_successor() + successor = episode.cut() # Run 100 more timesteps for the successor. successor, env = self._mock_multi_agent_records_from_env( episode=successor, env=env, init=False @@ -2007,7 +2005,7 @@ def _mock_multi_agent_records_from_env( # We initialize the episode, if requested. if init: obs, info = env.reset(seed=seed) - episode.add_initial_observation(initial_observation=obs, initial_info=info) + episode.add_env_reset(observations=obs, infos=info) # In the other case wer need at least the last observations for the next # actions. else: @@ -2020,15 +2018,15 @@ def _mock_multi_agent_records_from_env( # Sample `size` many records. for i in range(env.t, env.t + size): action = {agent_id: i + 1 for agent_id in obs} - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode.add_timestep( - observation=obs, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={agent_id: {"extra": 10} for agent_id in action}, + obs, reward, terminated, truncated, info = env.step(action) + episode.add_env_step( + observations=obs, + actions=action, + rewards=reward, + infos=info, + terminateds=terminated, + truncateds=truncated, + extra_model_outputs={agent_id: {"extra": 10} for agent_id in action}, ) # Return both, epsiode and environment. @@ -2066,17 +2064,21 @@ def _mock_multi_agent_records(self): {"agent_2": {}, "agent_4": {}}, ] # Let no agent terminate or being truncated. - is_terminated = [ - {"__all__": False, "agent_1": False, "agent_3": False, "agent_4": False}, - {"__all__": False, "agent_2": False, "agent_4": False}, - ] - is_truncated = [ - {"__all__": False, "agent_1": False, "agent_3": False, "agent_4": False}, - {"__all__": False, "agent_2": False, "agent_4": False}, - ] + terminateds = { + "__all__": False, + "agent_1": False, + "agent_3": False, + "agent_4": False, + } + truncateds = { + "__all__": False, + "agent_1": False, + "agent_3": False, + "agent_4": False, + } # Return all observations. - return observations, actions, rewards, is_terminated, is_truncated, infos + return observations, actions, rewards, terminateds, truncateds, infos if __name__ == "__main__": diff --git a/rllib/env/tests/test_single_agent_gym_env_runner.py b/rllib/env/tests/test_single_agent_env_runner.py similarity index 54% rename from rllib/env/tests/test_single_agent_gym_env_runner.py rename to rllib/env/tests/test_single_agent_env_runner.py index 015c3837ff0f..743e885291e3 100644 --- a/rllib/env/tests/test_single_agent_gym_env_runner.py +++ b/rllib/env/tests/test_single_agent_env_runner.py @@ -2,10 +2,10 @@ import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.env.testing.single_agent_gym_env_runner import SingleAgentGymEnvRunner +from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner -class TestSingleAgentGymEnvRunner(unittest.TestCase): +class TestSingleAgentEnvRunner(unittest.TestCase): @classmethod def setUpClass(cls) -> None: ray.init() @@ -20,43 +20,41 @@ def test_sample(self): # Vectorize x2 and by default, rollout 64 timesteps per individual env. .rollouts(num_envs_per_worker=2, rollout_fragment_length=64) ) - env_runner = SingleAgentGymEnvRunner(config=config) + env_runner = SingleAgentEnvRunner(config=config) # Expect error if both num_timesteps and num_episodes given. self.assertRaises( - AssertionError, lambda: env_runner.sample(num_timesteps=10, num_episodes=10) + AssertionError, + lambda: env_runner.sample( + num_timesteps=10, num_episodes=10, random_actions=True + ), ) # Sample 10 episodes (5 per env) 100 times. for _ in range(100): - done_episodes, ongoing_episodes = env_runner.sample(num_episodes=10) - self.assertTrue(len(done_episodes + ongoing_episodes) == 10) + episodes = env_runner.sample(num_episodes=10, random_actions=True) + self.assertTrue(len(episodes) == 10) # Since we sampled complete episodes, there should be no ongoing episodes # being returned. - assert len(ongoing_episodes) == 0 - # Check, whether all done_episodes returned are indeed terminated. - self.assertTrue(all(e.is_done for e in done_episodes)) + self.assertTrue(all(e.is_done for e in episodes)) # Sample 10 timesteps (5 per env) 100 times. for _ in range(100): - done_episodes, ongoing_episodes = env_runner.sample(num_timesteps=10) - # Check, whether all done_episodes returned are indeed terminated. - self.assertTrue(all(e.is_done for e in done_episodes)) - # Check, whether all done_episodes returned are indeed terminated. - self.assertTrue(not any(e.is_done for e in ongoing_episodes)) + episodes = env_runner.sample(num_timesteps=10, random_actions=True) + # Check, whether the sum of lengths of all episodes returned is 20 + self.assertTrue(sum(len(e) for e in episodes) == 10) # Sample (by default setting: rollout_fragment_length=64) 10 times. for _ in range(100): - done_episodes, ongoing_episodes = env_runner.sample() - # Check, whether all done_episodes returned are indeed terminated. - self.assertTrue(all(e.is_done for e in done_episodes)) - # Check, whether all done_episodes returned are indeed terminated. - self.assertTrue(not any(e.is_done for e in ongoing_episodes)) + episodes = env_runner.sample(random_actions=True) + # Check, whether the sum of lengths of all episodes returned is 128 + # 2 (num_env_per_worker) * 64 (rollout_fragment_length). + self.assertTrue(sum(len(e) for e in episodes) == 128) def test_distributed_env_runner(self): """Tests, whether SingleAgentGymEnvRunner can be distributed.""" - remote_class = ray.remote(num_cpus=1, num_gpus=0)(SingleAgentGymEnvRunner) + remote_class = ray.remote(num_cpus=1, num_gpus=0)(SingleAgentEnvRunner) # Test with both parallelized sub-envs and w/o. remote_worker_envs = [False, True] @@ -78,19 +76,13 @@ def test_distributed_env_runner(self): for _ in range(config.num_rollout_workers) ] # Sample in parallel. - results = [a.sample.remote() for a in array] + results = [a.sample.remote(random_actions=True) for a in array] results = ray.get(results) # Loop over individual EnvRunner Actor's results and inspect each. - for result in results: - # SingleAgentGymEnvRunners return tuples: (completed eps, ongoing eps). - completed, ongoing = result - # Make sure all completed Episodes are indeed done. - self.assertTrue(all(e.is_done for e in completed)) - # Same for ongoing ones (make sure they are not done). - self.assertTrue(not any(e.is_done for e in ongoing)) + for episodes in results: # Assert length of all fragments is `rollout_fragment_length`. self.assertEqual( - sum(len(e) for e in completed + ongoing), + sum(len(e) for e in episodes), config.num_envs_per_worker * config.rollout_fragment_length, ) diff --git a/rllib/env/tests/test_single_agent_episode.py b/rllib/env/tests/test_single_agent_episode.py index 03389407a9e3..7ada5e61cc94 100644 --- a/rllib/env/tests/test_single_agent_episode.py +++ b/rllib/env/tests/test_single_agent_episode.py @@ -1,12 +1,13 @@ -import gymnasium as gym -import numpy as np +from collections import defaultdict +from typing import Any, Dict, Optional, SupportsFloat, Tuple import unittest +import gymnasium as gym from gymnasium.core import ActType, ObsType -from typing import Any, Dict, Optional, SupportsFloat, Tuple +import numpy as np -import ray from ray.rllib.env.single_agent_episode import SingleAgentEpisode +from ray.rllib.utils.test_utils import check # TODO (simon): Add to the tests `info` and `extra_model_outputs` # as soon as #39732 is merged. @@ -37,14 +38,6 @@ def step( class TestSingelAgentEpisode(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - ray.init() - - @classmethod - def tearDownClass(cls) -> None: - ray.shutdown() - def test_init(self): """Tests initialization of `SingleAgentEpisode`. @@ -64,45 +57,35 @@ def test_init(self): episode = SingleAgentEpisode(t_started=10) self.assertTrue(episode.t == episode.t_started == 10) - # Sample 100 values and initialize episode with observations and infos. - env = gym.make("CartPole-v1") - # Initialize containers. - observations = [] - rewards = [] - actions = [] - infos = [] - extra_model_outputs = [] - states = np.random.random(10) - - # Initialize observation and info. - init_obs, init_info = env.reset() - observations.append(init_obs) - infos.append(init_info) - # Run 100 samples. - for _ in range(100): - action = env.action_space.sample() - obs, reward, is_terminated, is_truncated, info = env.step(action) - observations.append(obs) - actions.append(action) - rewards.append(reward) - infos.append(info) - extra_model_outputs.append({"extra_1": np.random.random()}) - - # Build the episode. - episode = SingleAgentEpisode( - observations=observations, - actions=actions, - rewards=rewards, - infos=infos, - states=states, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_outputs=extra_model_outputs, - ) + episode = self._create_episode(num_data=100) # The starting point and count should now be at `len(observations) - 1`. - self.assertTrue(episode.t == episode.t_started == (len(observations) - 1)) + self.assertTrue(len(episode) == 100) + self.assertTrue(episode.t == 100) + self.assertTrue(episode.t_started == 0) - def test_add_initial_observation(self): + # Build the same episode, but with a 10 ts lookback buffer. + episode = self._create_episode(num_data=100, len_lookback_buffer=10) + # The lookback buffer now takes 10 ts and the length of the episode is only 90. + self.assertTrue(len(episode) == 90) + # `t_started` is 0 by default. + self.assertTrue(episode.t_started == 0) + self.assertTrue(episode.t == 90) + self.assertTrue(len(episode.rewards) == 90) + self.assertTrue(len(episode.rewards.data) == 100) + + # Build the same episode, but with a 10 ts lookback buffer AND a specific + # `t_started`. + episode = self._create_episode( + num_data=100, len_lookback_buffer=10, t_started=50 + ) + # The lookback buffer now takes 10 ts and the length of the episode is only 90. + self.assertTrue(len(episode) == 90) + self.assertTrue(episode.t_started == 50) + self.assertTrue(episode.t == 140) + self.assertTrue(len(episode.rewards) == 90) + self.assertTrue(len(episode.rewards.data) == 100) + + def test_add_env_reset(self): """Tests adding initial observations and infos. This test ensures that when initial observation and info are provided @@ -116,7 +99,7 @@ def test_add_initial_observation(self): # Add initial observations. obs, info = env.reset() - episode.add_initial_observation(initial_observation=obs, initial_info=info) + episode.add_env_reset(observation=obs, infos=info) # Assert that the observations are added to their list. self.assertTrue(len(episode.observations) == 1) @@ -125,7 +108,7 @@ def test_add_initial_observation(self): # Assert that the timesteps are still at zero as we have not stepped, yet. self.assertTrue(episode.t == episode.t_started == 0) - def test_add_timestep(self): + def test_add_env_step(self): """Tests if adding timestep data to a `SingleAgentEpisode` works. Adding timestep data is the central part of collecting episode @@ -138,22 +121,23 @@ def test_add_timestep(self): # Set the random seed (otherwise the episode will terminate at # different points in each test run). obs, info = env.reset(seed=0) - episode.add_initial_observation(initial_observation=obs, initial_info=info) + episode.add_env_reset(observation=obs, infos=info) # Sample 100 timesteps and add them to the episode. + terminated = truncated = False for i in range(100): action = env.action_space.sample() - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode.add_timestep( + obs, reward, terminated, truncated, info = env.step(action) + episode.add_env_step( observation=obs, action=action, reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={"extra": np.random.random(1)}, + infos=info, + terminated=terminated, + truncated=truncated, + extra_model_outputs={"extra": np.random.random(1)}, ) - if is_terminated or is_truncated: + if terminated or truncated: break # Assert that the episode timestep is at 100. @@ -169,11 +153,11 @@ def test_add_timestep(self): == i + 1 ) # Assert that the flags are set correctly. - self.assertTrue(episode.is_terminated == is_terminated) - self.assertTrue(episode.is_truncated == is_truncated) - self.assertTrue(episode.is_done == is_terminated or is_truncated) + self.assertTrue(episode.is_terminated == terminated) + self.assertTrue(episode.is_truncated == truncated) + self.assertTrue(episode.is_done == terminated or truncated) - def test_create_successor(self): + def test_cut(self): """Tests creation of a scucessor of a `SingleAgentEpisode`. This test makes sure that when creating a successor the successor's @@ -188,28 +172,26 @@ def test_create_successor(self): env = TestEnv() # Add initial observation. init_obs, init_info = env.reset() - episode_1.add_initial_observation( - initial_observation=init_obs, initial_info=init_info - ) + episode_1.add_env_reset(observation=init_obs, infos=init_info) # Sample 100 steps. for i in range(100): action = i - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode_1.add_timestep( + obs, reward, terminated, truncated, info = env.step(action) + episode_1.add_env_step( observation=obs, action=action, reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={"extra": np.random.random(1)}, + infos=info, + terminated=terminated, + truncated=truncated, + extra_model_outputs={"extra": np.random.random(1)}, ) # Assert that the episode has indeed 100 timesteps. self.assertTrue(episode_1.t == 100) # Create a successor. - episode_2 = episode_1.create_successor() + episode_2 = episode_1.cut() # Assert that it has the same id. self.assertEqual(episode_1.id_, episode_2.id_) # Assert that the timestep starts at the end of the last episode. @@ -222,19 +204,126 @@ def test_create_successor(self): # Test immutability. action = 100 - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode_2.add_timestep( + obs, reward, terminated, truncated, info = env.step(action) + episode_2.add_env_step( observation=obs, action=action, reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={"extra": np.random.random(1)}, + infos=info, + terminated=terminated, + truncated=truncated, + extra_model_outputs={"extra": np.random.random(1)}, ) # Assert that this does not change also the predecessor's data. self.assertFalse(len(episode_1.observations) == len(episode_2.observations)) + def test_slices(self): + # TEST #1: even split (50/50) + episode = self._create_episode(100) + self.assertTrue(episode.t == 100 and episode.t_started == 0) + + # Convert to numpy before splitting. + episode.finalize() + + # Create two 50/50 episode chunks. + e1 = episode[:50] + self.assertTrue(e1.is_finalized) + e2 = episode.slice(slice(50, None)) + self.assertTrue(e2.is_finalized) + + # Make sure, `e1` and `e2` make sense. + self.assertTrue(len(e1) == 50) + self.assertTrue(len(e2) == 50) + self.assertTrue(e1.id_ == e2.id_) + self.assertTrue(e1.t_started == 0) + self.assertTrue(e1.t == 50) + self.assertTrue(e2.t_started == 50) + self.assertTrue(e2.t == 100) + # Make sure the chunks are not identical, but last obs of `e1` matches + # last obs of `e2`. + check(e1.get_observations(-1), e2.get_observations(0)) + check(e1.observations[4], e2.observations[4], false=True) + check(e1.observations[10], e2.observations[10], false=True) + + # TEST #2: Uneven split (33/66). + episode = self._create_episode(99) + self.assertTrue(episode.t == 99 and episode.t_started == 0) + + # Convert to numpy before splitting. + episode.finalize() + + # Create two 50/50 episode chunks. + e1 = episode.slice(slice(None, 33)) + self.assertTrue(e1.is_finalized) + e2 = episode[33:] + self.assertTrue(e2.is_finalized) + + # Make sure, `e1` and `e2` chunk make sense. + self.assertTrue(len(e1) == 33) + self.assertTrue(len(e2) == 66) + self.assertTrue(e1.id_ == e2.id_) + self.assertTrue(e1.t_started == 0) + self.assertTrue(e1.t == 33) + self.assertTrue(e2.t_started == 33) + self.assertTrue(e2.t == 99) + # Make sure the chunks are not identical, but last obs of `e1` matches + # last obs of `e2`. + check(e1.get_observations(-1), e2.get_observations(0)) + check(e1.observations[4], e2.observations[4], false=True) + check(e1.observations[10], e2.observations[10], false=True) + + # TEST #3: Split with lookback buffer (buffer=10, split=20/30). + episode = self._create_episode( + num_data=60, t_started=15, len_lookback_buffer=10 + ) + self.assertTrue(episode.t == 65 and episode.t_started == 15) + + # Convert to numpy before splitting. + episode.finalize() + + # Create two 20/30 episode chunks. + e1 = episode.slice(slice(None, 20)) + self.assertTrue(e1.is_finalized) + e2 = episode[20:] + self.assertTrue(e2.is_finalized) + + # Make sure, `e1` and `e2` make sense. + self.assertTrue(len(e1) == 20) + self.assertTrue(len(e2) == 30) + self.assertTrue(e1.id_ == e2.id_) + self.assertTrue(e1.t_started == 15) + self.assertTrue(e1.t == 35) + self.assertTrue(e2.t_started == 35) + self.assertTrue(e2.t == 65) + # Make sure the chunks are not identical, but last obs of `e1` matches + # last obs of `e2`. + check(e1.get_observations(-1), e2.get_observations(0)) + check(e1.observations[5], e2.observations[5], false=True) + check(e1.observations[11], e2.observations[11], false=True) + # Make sure the lookback buffers of both chunks are still working. + check( + e1.get_observations(-1, neg_indices_left_of_zero=True), + episode.observations.data[episode._len_lookback_buffer - 1], + ) + check( + e1.get_actions(-1, neg_indices_left_of_zero=True), + episode.actions.data[episode._len_lookback_buffer - 1], + ) + check( + e2.get_observations([-5, -2], neg_indices_left_of_zero=True), + [ + episode.observations.data[20 + episode._len_lookback_buffer - 5], + episode.observations.data[20 + episode._len_lookback_buffer - 2], + ], + ) + check( + e2.get_rewards([-5, -2], neg_indices_left_of_zero=True), + [ + episode.rewards.data[20 + episode._len_lookback_buffer - 5], + episode.rewards.data[20 + episode._len_lookback_buffer - 2], + ], + ) + def test_concat_episode(self): """Tests if concatenation of two `SingleAgentEpisode`s works. @@ -248,38 +337,36 @@ def test_concat_episode(self): env = TestEnv() init_obs, init_info = env.reset() episode_1 = SingleAgentEpisode() - episode_1.add_initial_observation( - initial_observation=init_obs, initial_info=init_info - ) + episode_1.add_env_reset(observation=init_obs, infos=init_info) # Sample 100 timesteps. for i in range(100): action = i - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode_1.add_timestep( + obs, reward, terminated, truncated, info = env.step(action) + episode_1.add_env_step( observation=obs, action=action, reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={"extra": np.random.random(1)}, + infos=info, + terminated=terminated, + truncated=truncated, + extra_model_outputs={"extra": np.random.random(1)}, ) # Create a successor. - episode_2 = episode_1.create_successor() + episode_2 = episode_1.cut() # Now, sample 100 more timesteps. for i in range(100, 200): action = i - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode_2.add_timestep( + obs, reward, terminated, truncated, info = env.step(action) + episode_2.add_env_step( observation=obs, action=action, reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={"extra": np.random.random(1)}, + infos=info, + terminated=terminated, + truncated=truncated, + extra_model_outputs={"extra": np.random.random(1)}, ) # Assert that the second episode's `t_started` is at the first episode's @@ -309,7 +396,7 @@ def test_concat_episode(self): # Reset `is_terminated`. episode_1.is_terminated = False - # Concate the episodes. + # Concatenate the episodes. episode_1.concat_episode(episode_2) # Assert that the concatenated episode start at `t_started=0` @@ -331,118 +418,39 @@ def test_concat_episode(self): # self.assertNotEqual(id(episode_2.observations[5]), # id(episode_1.observations[105])) - def test_get_and_from_state(self): - """Tests, if a `SingleAgentEpisode` can be reconstructed form state. + def _create_episode(self, num_data, t_started=None, len_lookback_buffer=0): + # Sample 100 values and initialize episode with observations and infos. + env = gym.make("CartPole-v1") + # Initialize containers. + observations = [] + rewards = [] + actions = [] + infos = [] + extra_model_outputs = defaultdict(list) - This test constructs an episode, stores it to its dictionary state and - recreates a new episode form this state. Thereby it ensures that all - atttributes are indeed identical to the primer episode and the data is - complete. - """ - # Create an empty episode. - episode = SingleAgentEpisode() - # Create an environment. - env = TestEnv() - # Add initial observation. + # Initialize observation and info. init_obs, init_info = env.reset() - episode.add_initial_observation( - initial_observation=init_obs, initial_info=init_info - ) - # Sample 100 steps. - for i in range(100): - action = i - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode.add_timestep( - observation=obs, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={"extra": np.random.random(1)}, - ) - - # Get the state and reproduce it from state. - state = episode.get_state() - episode_reproduced = SingleAgentEpisode.from_state(state) - - # Assert that the data is complete. - self.assertEqual(episode.id_, episode_reproduced.id_) - self.assertEqual(episode.t, episode_reproduced.t) - self.assertEqual(episode.t_started, episode_reproduced.t_started) - self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated) - self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated) - self.assertListEqual(episode.observations, episode_reproduced.observations) - self.assertListEqual(episode.actions, episode_reproduced.actions) - self.assertListEqual(episode.rewards, episode_reproduced.rewards) - self.assertListEqual(episode.infos, episode_reproduced.infos) - self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated) - self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated) - self.assertEqual(episode.states, episode_reproduced.states) - self.assertListEqual(episode.render_images, episode_reproduced.render_images) - self.assertDictEqual( - episode.extra_model_outputs, episode_reproduced.extra_model_outputs - ) - - # Assert that reconstruction breaks, if the data is not complete. - state[1][1].pop() - with self.assertRaises(AssertionError): - episode_reproduced = SingleAgentEpisode.from_state(state) - - def test_to_and_from_sample_batch(self): - """Tests if a `SingelAgentEpisode` can be reconstructed from a `SampleBatch`. - - This tests converst an episode to a `SampleBatch` and reconstructs the - episode then from this sample batch. It is then tested, if all data is - complete. - Note that `extra_model_outputs` are defined by the user and as the format - in the episode from which a `SampleBatch` was created is unknown this - reconstruction would only work, if the user does take care of it (as a - counter example just rempve the index [0] from the `extra_model_output`). - """ - # Create an empty episode. - episode = SingleAgentEpisode() - # Create an environment. - env = TestEnv() - # Add initial observation. - init_obs, init_obs = env.reset() - episode.add_initial_observation( - initial_observation=init_obs, initial_info=init_obs - ) - # Sample 100 steps. - for i in range(100): - action = i - obs, reward, is_terminated, is_truncated, info = env.step(action) - episode.add_timestep( - observation=obs, - action=action, - reward=reward, - info=info, - is_terminated=is_terminated, - is_truncated=is_truncated, - extra_model_output={"extra": np.random.random(1)[0]}, - ) + observations.append(init_obs) + infos.append(init_info) + # Run n samples. + for _ in range(num_data): + action = env.action_space.sample() + obs, reward, _, _, info = env.step(action) + observations.append(obs) + actions.append(action) + rewards.append(reward) + infos.append(info) + extra_model_outputs["extra_1"].append(np.random.random()) + extra_model_outputs["state_out"].append(np.random.random()) - # Create `SampleBatch`. - batch = episode.to_sample_batch() - # Reproduce form `SampleBatch`. - episode_reproduced = SingleAgentEpisode.from_sample_batch(batch) - # Assert that the data is complete. - self.assertEqual(episode.id_, episode_reproduced.id_) - self.assertEqual(episode.t, episode_reproduced.t) - self.assertEqual(episode.t_started, episode_reproduced.t_started) - self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated) - self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated) - self.assertListEqual(episode.observations, episode_reproduced.observations) - self.assertListEqual(episode.actions, episode_reproduced.actions) - self.assertListEqual(episode.rewards, episode_reproduced.rewards) - self.assertEqual(episode.infos, episode_reproduced.infos) - self.assertEqual(episode.is_terminated, episode_reproduced.is_terminated) - self.assertEqual(episode.is_truncated, episode_reproduced.is_truncated) - self.assertEqual(episode.states, episode_reproduced.states) - self.assertListEqual(episode.render_images, episode_reproduced.render_images) - self.assertDictEqual( - episode.extra_model_outputs, episode_reproduced.extra_model_outputs + return SingleAgentEpisode( + observations=observations, + infos=infos, + actions=actions, + rewards=rewards, + extra_model_outputs=extra_model_outputs, + t_started=t_started, + len_lookback_buffer=len_lookback_buffer, ) diff --git a/rllib/env/utils.py b/rllib/env/utils.py index a9deba95548d..30642462bbc1 100644 --- a/rllib/env/utils.py +++ b/rllib/env/utils.py @@ -1,7 +1,9 @@ import logging -from typing import Type, Union +from typing import List, Optional, Type, Union import gymnasium as gym +import numpy as np +import tree # pip install dm_tree from ray.rllib.env.env_context import EnvContext from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -14,6 +16,12 @@ EnvError, ) from ray.rllib.utils.gym import check_old_gym_env +from ray.rllib.utils.numpy import one_hot, one_hot_multidiscrete +from ray.rllib.utils.spaces.space_utils import ( + batch, + get_dummy_batch_for_space, + get_base_struct_from_space, +) from ray.util import log_once from ray.util.annotations import PublicAPI @@ -169,3 +177,328 @@ def _gym_env_creator( raise EnvError(ERR_MSG_INVALID_ENV_DESCRIPTOR.format(env_descriptor)) return env + + +class BufferWithInfiniteLookback: + def __init__( + self, + data: Optional[Union[List, np.ndarray]] = None, + lookback: int = 0, + space: Optional[gym.Space] = None, + ): + self.data = data if data is not None else [] + self.lookback = lookback + self.finalized = not isinstance(self.data, list) + self._final_len = None + self.space = space + self.space_struct = get_base_struct_from_space(self.space) + + def append(self, item) -> None: + """Appends the given item to the end of this buffer.""" + if self.finalized: + raise RuntimeError(f"Cannot `append` to a finalized {type(self).__name__}.") + self.data.append(item) + + def extend(self, items): + """Appends all items in `items` to the end of this buffer.""" + if self.finalized: + raise RuntimeError(f"Cannot `extend` a finalized {type(self).__name__}.") + for item in items: + self.append(item) + + def pop(self, index: int = -1): + """Removes the item at `index` from this buffer.""" + if self.finalized: + raise RuntimeError(f"Cannot `pop` from a finalized {type(self).__name__}.") + return self.data.pop(index) + + def finalize(self): + """Finalizes this buffer by converting internal data lists into numpy arrays. + + Thereby, if the individual items in the list are complex (nested 2) + """ + if not self.finalized: + self._final_len = len(self.data) - self.lookback + self.data = batch(self.data) + self.finalized = True + + def get( + self, + indices: Optional[Union[int, slice, List[int]]] = None, + neg_indices_left_of_zero: bool = False, + fill: Optional[float] = None, + one_hot_discrete: bool = False, + ): + """Returns data, based on the given args, from this buffer. + + Args: + indices: A single int is interpreted as an index, from which to return the + individual data stored at this index. + A list of ints is interpreted as a list of indices from which to gather + individual data in a batch of size len(indices). + A slice object is interpreted as a range of data to be returned. + Thereby, negative indices by default are interpreted as "before the end" + unless the `neg_indices_left_of_zero=True` option is used, in which case + negative indices are interpreted as "before ts=0", meaning going back + into the lookback buffer. + neg_indices_left_of_zero: If True, negative values in `indices` are + interpreted as "before ts=0", meaning going back into the lookback + buffer. For example, an buffer with data [4, 5, 6, 7, 8, 9], + where [4, 5, 6] is the lookback buffer range (ts=0 item is 7), will + respond to `get(-1, neg_indices_left_of_zero=True)` with `6` and to + `get(slice(-2, 1), neg_indices_left_of_zero=True)` with `[5, 6, 7]`. + fill: An optional float value to use for filling up the returned results at + the boundaries. This filling only happens if the requested index range's + start/stop boundaries exceed the buffer's boundaries (including the + lookback buffer on the left side). This comes in very handy, if users + don't want to worry about reaching such boundaries and want to zero-pad. + For example, a buffer with data [10, 11, 12, 13, 14] and lookback + buffer size of 2 (meaning `10` and `11` are part of the lookback buffer) + will respond to `get(slice(-7, -2), fill=0.0)` + with `[0.0, 0.0, 10, 11, 12]`. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) space + that are Discrete or MultiDiscrete. Note that if `fill=0` and the + requested `indices` are out of the range of our data, the returned + one-hot vectors will actually be zero-hot (all slots zero). + """ + if fill is not None and self.space is None: + raise ValueError( + f"Cannot use `fill` argument in `{type(self).__name__}.get()` if a " + "gym.Space was NOT provided during construction!" + ) + + if indices is None: + data = self._get_all_data(one_hot_discrete=one_hot_discrete) + elif isinstance(indices, slice): + data = self._get_slice( + indices, + fill=fill, + neg_indices_left_of_zero=neg_indices_left_of_zero, + one_hot_discrete=one_hot_discrete, + ) + elif isinstance(indices, list): + data = [ + self._get_int_index( + idx, + fill=fill, + neg_indices_left_of_zero=neg_indices_left_of_zero, + one_hot_discrete=one_hot_discrete, + ) + for idx in indices + ] + if self.finalized: + data = batch(data) + else: + assert isinstance(indices, int) + data = self._get_int_index( + indices, + fill=fill, + neg_indices_left_of_zero=neg_indices_left_of_zero, + one_hot_discrete=one_hot_discrete, + ) + + return data + + def __getitem__(self, item): + """Support squared bracket syntax, e.g. buffer[:5].""" + return self.get(item) + + def __len__(self): + """Return the length of our data, excluding the lookback buffer.""" + if self._final_len is not None: + assert self.finalized + return self._final_len + return len(self.data) - self.lookback + + def _get_all_data(self, one_hot_discrete=False): + data = self[:] + if one_hot_discrete: + data = self._one_hot(data, space_struct=self.space_struct) + return data + + def _get_slice( + self, + slice_, + fill=None, + neg_indices_left_of_zero=False, + one_hot_discrete=False, + ): + len_self_plus_lookback = len(self) + self.lookback + fill_left_count = fill_right_count = 0 + + # Re-interpret slice bounds as absolute positions (>=0) within our + # internal data. + start = slice_.start + stop = slice_.stop + + # Start is None -> Exclude lookback buffer. + if start is None: + start = self.lookback + # Start is negative. + elif start < 0: + # `neg_indices_left_of_zero=True` -> User wants to index into the lookback + # range. + if neg_indices_left_of_zero: + start = self.lookback + start + # Interpret index as counting "from end". + else: + start = len_self_plus_lookback + start + # Start is 0 or positive -> timestep right after lookback is interpreted as 0. + else: + start = self.lookback + start + + # Stop is None -> Set stop to very last index + 1 of our internal data. + if stop is None: + stop = len_self_plus_lookback + # Stop is negative. + elif stop < 0: + # `neg_indices_left_of_zero=True` -> User wants to index into the lookback + # range. Set to 0 (beginning of lookback buffer) if result is a negative + # index. + if neg_indices_left_of_zero: + stop = self.lookback + stop + # Interpret index as counting "from end". Set to 0 (beginning of actual + # episode) if result is a negative index. + else: + stop = len_self_plus_lookback + stop + # Stop is positive -> Add lookback range to it. + else: + stop = self.lookback + stop + + # Both start and stop are on left side. + if start < 0 and stop < 0: + fill_left_count = abs(start - stop) + fill_right_count = 0 + start = stop = 0 + # Both start and stop are on right side. + elif start >= len_self_plus_lookback and stop >= len_self_plus_lookback: + fill_right_count = abs(start - stop) + fill_left_count = 0 + start = stop = len_self_plus_lookback + # Set to 0 (beginning of actual episode) if result is a negative index. + elif start < 0: + fill_left_count = -start + start = 0 + elif stop >= len_self_plus_lookback: + fill_right_count = stop - len_self_plus_lookback + stop = len_self_plus_lookback + + assert start >= 0 and stop >= 0, (start, stop) + assert start <= len_self_plus_lookback and stop <= len_self_plus_lookback, ( + start, + stop, + ) + slice_ = slice(start, stop, slice_.step) + + # Perform the actual slice. + if self.finalized: + data_slice = tree.map_structure(lambda s: s[slice_], self.data) + else: + data_slice = self.data[slice_] + + if one_hot_discrete: + data_slice = self._one_hot(data_slice, space_struct=self.space_struct) + + # Data is shorter than the range requested -> Fill the rest with `fill` data. + if fill is not None and (fill_right_count > 0 or fill_left_count > 0): + if self.finalized: + if fill_left_count: + fill_batch = get_dummy_batch_for_space( + self.space, + fill_value=fill, + batch_size=fill_left_count, + one_hot_discrete=one_hot_discrete, + ) + data_slice = tree.map_structure( + lambda s0, s: np.concatenate([s0, s]), fill_batch, data_slice + ) + if fill_right_count: + fill_batch = get_dummy_batch_for_space( + self.space, + fill_value=fill, + batch_size=fill_right_count, + one_hot_discrete=one_hot_discrete, + ) + data_slice = tree.map_structure( + lambda s0, s: np.concatenate([s, s0]), fill_batch, data_slice + ) + + else: + fill_batch = [ + get_dummy_batch_for_space( + self.space, + fill_value=fill, + batch_size=0, + one_hot_discrete=one_hot_discrete, + ) + ] + data_slice = ( + fill_batch * fill_left_count + + data_slice + + fill_batch * fill_right_count + ) + + return data_slice + + def _get_int_index( + self, + idx: int, + fill=None, + neg_indices_left_of_zero=False, + one_hot_discrete=False, + ): + # If index >= 0 -> Ignore lookback buffer. + # Otherwise, include lookback buffer. + if idx >= 0 or neg_indices_left_of_zero: + idx = self.lookback + idx + # Negative indices mean: Go to left into lookback buffer starting from idx=0. + # But if we pass the lookback buffer, the index should be invalid and we will + # have to fill, if required. Invalidate the index by setting it to one larger + # than max. + if neg_indices_left_of_zero and idx < 0: + idx = len(self) + self.lookback + + try: + if self.finalized: + data = tree.map_structure(lambda s: s[idx], self.data) + else: + data = self.data[idx] + # Out of range index -> If `fill`, use a fill dummy (B=0), if not, error out. + except IndexError as e: + if fill is not None: + return get_dummy_batch_for_space( + self.space, + fill_value=fill, + batch_size=0, + one_hot_discrete=one_hot_discrete, + ) + else: + raise e + + # Convert discrete/multi-discrete components to one-hot vectors, if required. + if one_hot_discrete: + data = self._one_hot(data, self.space_struct) + return data + + def _one_hot(self, data, space_struct): + if space_struct is None: + raise ValueError( + f"Cannot `one_hot` data in `{type(self).__name__}` if a " + "gym.Space was NOT provided during construction!" + ) + + def _convert(dat_, space): + if isinstance(space, gym.spaces.Discrete): + return one_hot(dat_, depth=space.n) + elif isinstance(space, gym.spaces.MultiDiscrete): + return one_hot_multidiscrete(dat_, depths=space.nvec) + return dat_ + + if isinstance(data, list): + data = [ + tree.map_structure(_convert, dslice, space_struct) for dslice in data + ] + else: + data = tree.map_structure(_convert, data, space_struct) + return data diff --git a/rllib/evaluation/postprocessing_v2.py b/rllib/evaluation/postprocessing_v2.py index 9fae6c1ce325..a50da0718e5d 100644 --- a/rllib/evaluation/postprocessing_v2.py +++ b/rllib/evaluation/postprocessing_v2.py @@ -6,13 +6,13 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.core.models.base import STATE_IN from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.evaluation.postprocessing import discount_cumsum from ray.rllib.policy.sample_batch import concat_samples, SampleBatch from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.utils.torch_utils import convert_to_torch_tensor from ray.rllib.utils.typing import TensorType @@ -43,10 +43,10 @@ def postprocess_episodes_to_sample_batch( # a list. if isinstance(episode_or_list, list): for episode in episode_or_list: - batches.append(episode.to_sample_batch()) + batches.append(episode.get_sample_batch()) # During exploration we have an episode. else: - batches.append(episode_or_list.to_sample_batch()) + batches.append(episode_or_list.get_sample_batch()) batch = concat_samples(batches) # TODO (sven): During evalaution we do not have infos at all. diff --git a/rllib/tuned_examples/ppo/memory-leak-test-ppo-new-stack.py b/rllib/tuned_examples/ppo/memory-leak-test-ppo-new-stack.py new file mode 100644 index 000000000000..f049692edfaf --- /dev/null +++ b/rllib/tuned_examples/ppo/memory-leak-test-ppo-new-stack.py @@ -0,0 +1,17 @@ +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner +from ray.rllib.examples.env.random_env import RandomLargeObsSpaceEnv + + +config = ( + PPOConfig() + .experimental(_enable_new_api_stack=True) + # Switch off np.random, which is known to have memory leaks. + .environment(RandomLargeObsSpaceEnv, env_config={"static_samples": True}) + .rollouts( + env_runner_cls=SingleAgentEnvRunner, + num_rollout_workers=4, + num_envs_per_worker=5, + ) + .training(train_batch_size=500, sgd_minibatch_size=256, num_sgd_iter=5) +) diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py index 371885f3744b..9f040c8a0c28 100644 --- a/rllib/utils/numpy.py +++ b/rllib/utils/numpy.py @@ -235,7 +235,8 @@ def flatten_inputs_to_1d_tensor( Args: inputs: The inputs to be flattened. - spaces_struct: The structure of the spaces that behind the input + spaces_struct: The (possibly nested) structure of the spaces that `inputs` + belongs to. time_axis: Whether all inputs have a time-axis (after the batch axis). If True, will keep not only the batch axis (0th), but the time axis (1st) as-is and flatten everything from the 2nd axis up. @@ -501,10 +502,7 @@ def one_hot( ) shape = x.shape - # Python 2.7 compatibility, (*shape, depth) is not allowed. - shape_list = list(shape[:]) - shape_list.append(depth) - out = np.ones(shape_list) * off_value + out = np.ones(shape=(*shape, depth)) * off_value indices = [] for i in range(x.ndim): tiles = [1] * x.ndim @@ -520,6 +518,22 @@ def one_hot( return out.astype(dtype) +@PublicAPI +def one_hot_multidiscrete(x, depths=List[int]): + # Handle torch arrays properly. + if torch and isinstance(x, torch.Tensor): + x = x.numpy() + + shape = x.shape + return np.concatenate( + [ + one_hot(x[i] if len(shape) == 1 else x[:, i], depth=n).astype(np.float32) + for i, n in enumerate(depths) + ], + axis=-1, + ) + + @PublicAPI def relu(x: np.ndarray, alpha: float = 0.0) -> np.ndarray: """Implementation of the leaky ReLU function. diff --git a/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py index 95662eb14cbf..2f9dd1b20e10 100644 --- a/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_episode_replay_buffer.py @@ -13,11 +13,11 @@ def _get_episode(episode_len=None, id_=None): eps = SingleAgentEpisode(id_=id_, observations=[0.0], infos=[{}]) ts = np.random.randint(1, 200) if episode_len is None else episode_len for t in range(ts): - eps.add_timestep( + eps.add_env_step( observation=float(t + 1), action=int(t), reward=0.1 * (t + 1), - info={}, + infos={}, ) eps.is_terminated = np.random.random() > 0.5 eps.is_truncated = False if eps.is_terminated else np.random.random() > 0.8 diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py index c3d621189d8f..426abb13c012 100644 --- a/rllib/utils/spaces/space_utils.py +++ b/rllib/utils/spaces/space_utils.py @@ -100,11 +100,13 @@ def get_dummy_batch_for_space( fill_value: Union[float, int, str] = 0.0, time_size: Optional[int] = None, time_major: bool = False, + one_hot_discrete: bool = False, ) -> np.ndarray: """Returns batched dummy data (using `batch_size`) for the given `space`. Note: The returned batch will not pass a `space.contains(batch)` test - as an additional batch dimension has to be added as dim=0. + as an additional batch dimension has to be added at axis 0, unless `batch_size` is + set to 0. Args: space: The space to get a dummy batch for. @@ -114,23 +116,45 @@ def get_dummy_batch_for_space( fill_value: The value to fill the batch with or "random" for random values. time_size: If not None, add an optional time axis - of `time_size` size to the returned batch. + of `time_size` size to the returned batch. This time axis might either + be inserted at axis=1 (default) or axis=0, if `time_major` is True. time_major: If True AND `time_size` is not None, return batch as shape [T x B x ...], otherwise as [B x T x ...]. If `time_size` if None, ignore this setting and return [B x ...]. + one_hot_discrete: If True, will return one-hot vectors (instead of + int-values) for those sub-components of a (possibly complex) `space` + that are Discrete or MultiDiscrete. Note that in case `fill_value` is 0.0, + this will result in zero-hot vectors (where all slots have a value of 0.0). Returns: The dummy batch of size `bqtch_size` matching the given space. """ # Complex spaces. Perform recursive calls of this function. - if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): + if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, dict, tuple)): + base_struct = space + if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)): + base_struct = get_base_struct_from_space(space) return tree.map_structure( - lambda s: get_dummy_batch_for_space(s, batch_size, fill_value), - get_base_struct_from_space(space), + lambda s: get_dummy_batch_for_space( + space=s, + batch_size=batch_size, + fill_value=fill_value, + time_size=time_size, + time_major=time_major, + one_hot_discrete=one_hot_discrete, + ), + base_struct, ) + + if one_hot_discrete: + if isinstance(space, gym.spaces.Discrete): + space = gym.spaces.Box(0.0, 1.0, (space.n,), np.float32) + elif isinstance(space, gym.spaces.MultiDiscrete): + space = gym.spaces.Box(0.0, 1.0, (np.sum(space.nvec),), np.float32) + # Primivite spaces: Box, Discrete, MultiDiscrete. # Random values: Use gym's sample() method. - elif fill_value == "random": + if fill_value == "random": if time_size is not None: assert batch_size > 0 and time_size > 0 if time_major: diff --git a/rllib/utils/spaces/tests/test_space_utils.py b/rllib/utils/spaces/tests/test_space_utils.py index c0200b6870dd..69d007ed12d0 100644 --- a/rllib/utils/spaces/tests/test_space_utils.py +++ b/rllib/utils/spaces/tests/test_space_utils.py @@ -76,6 +76,12 @@ def test_unsquash_action(self): def test_batch_and_unbatch(self): """Tests the two utility functions `batch` and `unbatch`.""" + # Test, whether simple structs are batch/unbatch'able as well. + # B=8 + simple_struct = [0, 1, 2, 3, 4, 5, 6, 7] + simple_struct_batched = batch(simple_struct) + check(unbatch(simple_struct_batched), simple_struct) + # Create a complex struct of individual batches (B=2). complex_struct = { "a": ( diff --git a/rllib/utils/tests/run_memory_leak_tests.py b/rllib/utils/tests/run_memory_leak_tests.py index 5b770a088410..4fc509fd7c88 100644 --- a/rllib/utils/tests/run_memory_leak_tests.py +++ b/rllib/utils/tests/run_memory_leak_tests.py @@ -93,6 +93,8 @@ # For python files, need to make sure, we only deliver the module name into the # `load_experiments_from_file` function (everything from "/ray/rllib" on). if file.endswith(".py"): + if file.endswith("__init__.py"): # weird CI learning test (BAZEL) case + continue experiments = load_experiments_from_file(file, SupportedFileType.python) else: experiments = load_experiments_from_file(file, SupportedFileType.yaml) @@ -122,11 +124,12 @@ leaking = True try: ray.init(num_cpus=5, local_mode=args.local_mode) - algo = get_trainable_cls(experiment["run"])(experiment["config"]) - results = check_memory_leaks( - algo, - to_check=set(args.to_check), - ) + if isinstance(experiment["run"], str): + algo_cls = get_trainable_cls(experiment["run"]) + else: + algo_cls = get_trainable_cls(experiment["run"].__name__) + algo = algo_cls(experiment["config"]) + results = check_memory_leaks(algo, to_check=set(args.to_check)) if not results: leaking = False finally: