diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index 4689c9261d8a..6ebc3d097018 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta): @abstractmethod def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, - policy_id: PolicyID, init_obs: TensorType) -> None: + policy_id: PolicyID, t: int, + init_obs: TensorType) -> None: """Adds an initial obs (after reset) to this collector. Since the very first observation in an environment is collected w/o @@ -48,6 +49,8 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, values for. env_id (EnvID): The environment index (in a vectorized setup). policy_id (PolicyID): Unique id for policy controlling the agent. + t (int): The time step (episode length - 1). The initial obs has + t=-1(!), then an action/reward/next-obs at t=0, etc. init_obs (TensorType): Initial observation (after env.reset()). Examples: @@ -172,9 +175,10 @@ def postprocess_episode(self, MultiAgentBatch. Used for batch_mode=`complete_episodes`. Returns: - Any: An ID that can be used in `build_multi_agent_batch` to - retrieve the samples that have been postprocessed as a - ready-built MultiAgentBatch. + Optional[MultiAgentBatch]: If `build` is True, the + SampleBatch or MultiAgentBatch built from `episode` (either + just from that episode or from the `_PolicyCollectorGroup` + in the `episode.batch_builder` property). """ raise NotImplementedError diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index b7a9e5244240..7db62c5329c9 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -52,17 +52,19 @@ def __init__(self, shift_before: int = 0): # each time a (non-initial!) observation is added. self.count = 0 - def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, - env_id: EnvID, init_obs: TensorType, + def add_init_obs(self, episode_id: EpisodeID, agent_index: int, + env_id: EnvID, t: int, init_obs: TensorType, view_requirements: Dict[str, ViewRequirement]) -> None: """Adds an initial observation (after reset) to the Agent's trajectory. Args: episode_id (EpisodeID): Unique ID for the episode we are adding the initial observation for. - agent_id (AgentID): Unique ID for the agent we are adding the - initial observation for. + agent_index (int): Unique int index (starting from 0) for the agent + within its episode. env_id (EnvID): The environment index (in a vectorized setup). + t (int): The time step (episode length - 1). The initial obs has + t=-1(!), then an action/reward/next-obs at t=0, etc. init_obs (TensorType): The initial observation tensor (after `env.reset()`). 
view_requirements (Dict[str, ViewRequirements]) @@ -72,10 +74,15 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID, single_row={ SampleBatch.OBS: init_obs, SampleBatch.EPS_ID: episode_id, - SampleBatch.AGENT_INDEX: agent_id, + SampleBatch.AGENT_INDEX: agent_index, "env_id": env_id, + "t": t, }) self.buffers[SampleBatch.OBS].append(init_obs) + self.buffers[SampleBatch.EPS_ID].append(episode_id) + self.buffers[SampleBatch.AGENT_INDEX].append(agent_index) + self.buffers["env_id"].append(env_id) + self.buffers["t"].append(t) def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \ None: @@ -133,7 +140,7 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \ continue # OBS are already shifted by -1 (the initial obs starts one ts # before all other data columns). - shift = view_req.shift - \ + shift = view_req.data_rel_pos - \ (1 if data_col == SampleBatch.OBS else 0) if data_col not in np_data: np_data[data_col] = to_float_np_array(self.buffers[data_col]) @@ -187,7 +194,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None: for col, data in single_row.items(): if col in self.buffers: continue - shift = self.shift_before - (1 if col == SampleBatch.OBS else 0) + shift = self.shift_before - (1 if col in [ + SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX, + "env_id", "t" + ] else 0) # Python primitive or dict (e.g. INFOs). if isinstance(data, (int, float, bool, str, dict)): self.buffers[col] = [0 for _ in range(shift)] @@ -360,7 +370,7 @@ def episode_step(self, episode_id: EpisodeID) -> None: @override(_SampleCollector) def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, - env_id: EnvID, policy_id: PolicyID, + env_id: EnvID, policy_id: PolicyID, t: int, init_obs: TensorType) -> None: # Make sure our mappings are up to date. agent_key = (episode.episode_id, agent_id) @@ -378,8 +388,9 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID, self.agent_collectors[agent_key] = _AgentCollector() self.agent_collectors[agent_key].add_init_obs( episode_id=episode.episode_id, - agent_id=agent_id, + agent_index=episode._agent_index(agent_id), env_id=env_id, + t=t, init_obs=init_obs, view_requirements=view_reqs) @@ -429,7 +440,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ # Create the batch of data from the different buffers. data_col = view_req.data_col or view_col time_indices = \ - view_req.shift - ( + view_req.data_rel_pos - ( 1 if data_col in [SampleBatch.OBS, "t", "env_id", SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX] else 0) diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 6d6d532d57c6..31d9d6506bb4 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -272,9 +272,11 @@ def __init__( output_creator (Callable[[IOContext], OutputWriter]): Function that returns an OutputWriter object for saving generated experiences. - remote_worker_envs (bool): If using num_envs > 1, whether to create - those new envs in remote processes instead of in the current - process. This adds overheads, but can make sense if your envs + remote_worker_envs (bool): If using num_envs_per_worker > 1, + whether to create those new envs in remote processes instead of + in the current process. This adds overheads, but can make sense + if your envs are expensive to step/reset (e.g., for StarCraft). + Use this cautiously; overheads are significant! 
remote_env_batch_wait_ms (float): Timeout that remote workers are waiting when polling environments. 0 (continue when at least one env is ready) is a reasonable default, but optimal diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index cb54c9372faa..804c269aa749 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api( # Record transition info if applicable. if last_observation is None: _sample_collector.add_init_obs(episode, agent_id, env_id, - policy_id, filtered_obs) + policy_id, episode.length - 1, + filtered_obs) else: # Add actions, rewards, next-obs to collectors. values_dict = { @@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api( # Add initial obs to buffer. _sample_collector.add_init_obs( - new_episode, agent_id, env_id, policy_id, filtered_obs) + new_episode, agent_id, env_id, policy_id, + new_episode.length - 1, filtered_obs) item = PolicyEvalData( env_id, agent_id, filtered_obs, diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index bd2488c47ed6..3fc23289c129 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -59,7 +59,7 @@ def test_traj_view_normal_case(self): assert view_req_policy[key].data_col is None else: assert view_req_policy[key].data_col == SampleBatch.OBS - assert view_req_policy[key].shift == 1 + assert view_req_policy[key].data_rel_pos == 1 rollout_worker = trainer.workers.local_worker() sample_batch = rollout_worker.sample() expected_count = \ @@ -99,10 +99,10 @@ def test_traj_view_lstm_prev_actions_and_rewards(self): if key == SampleBatch.PREV_ACTIONS: assert view_req_policy[key].data_col == SampleBatch.ACTIONS - assert view_req_policy[key].shift == -1 + assert view_req_policy[key].data_rel_pos == -1 elif key == SampleBatch.PREV_REWARDS: assert view_req_policy[key].data_col == SampleBatch.REWARDS - assert view_req_policy[key].shift == -1 + assert view_req_policy[key].data_rel_pos == -1 elif key not in [ SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS, SampleBatch.PREV_REWARDS @@ -110,7 +110,7 @@ def test_traj_view_lstm_prev_actions_and_rewards(self): assert view_req_policy[key].data_col is None else: assert view_req_policy[key].data_col == SampleBatch.OBS - assert view_req_policy[key].shift == 1 + assert view_req_policy[key].data_rel_pos == 1 trainer.stop() def test_traj_view_simple_performance(self): diff --git a/rllib/examples/policy/episode_env_aware_policy.py b/rllib/examples/policy/episode_env_aware_policy.py index ff8f91271479..89d4c525efbd 100644 --- a/rllib/examples/policy/episode_env_aware_policy.py +++ b/rllib/examples/policy/episode_env_aware_policy.py @@ -28,14 +28,16 @@ class _fake_model: "t": ViewRequirement(), SampleBatch.OBS: ViewRequirement(), SampleBatch.PREV_ACTIONS: ViewRequirement( - SampleBatch.ACTIONS, space=self.action_space, shift=-1), + SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1), SampleBatch.PREV_REWARDS: ViewRequirement( - SampleBatch.REWARDS, shift=-1), + SampleBatch.REWARDS, data_rel_pos=-1), } for i in range(2): self.model.inference_view_requirements["state_in_{}".format(i)] = \ ViewRequirement( - "state_out_{}".format(i), shift=-1, space=self.state_space) + "state_out_{}".format(i), + data_rel_pos=-1, + space=self.state_space) self.model.inference_view_requirements[ "state_out_{}".format(i)] = \ ViewRequirement(space=self.state_space) @@ -43,7 
+45,7 @@ class _fake_model: self.view_requirements = dict( **{ SampleBatch.NEXT_OBS: ViewRequirement( - SampleBatch.OBS, shift=1), + SampleBatch.OBS, data_rel_pos=1), SampleBatch.ACTIONS: ViewRequirement(space=self.action_space), SampleBatch.REWARDS: ViewRequirement(), SampleBatch.DONES: ViewRequirement(), diff --git a/rllib/examples/policy/rock_paper_scissors_dummies.py b/rllib/examples/policy/rock_paper_scissors_dummies.py index 72b2fdd518c9..011d49e5633e 100644 --- a/rllib/examples/policy/rock_paper_scissors_dummies.py +++ b/rllib/examples/policy/rock_paper_scissors_dummies.py @@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs): self.view_requirements.update({ "state_in_0": ViewRequirement( "state_out_0", - shift=-1, + data_rel_pos=-1, space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32)) }) diff --git a/rllib/models/modelv2.py b/rllib/models/modelv2.py index 4ac047c59db2..21bc139d4d6c 100644 --- a/rllib/models/modelv2.py +++ b/rllib/models/modelv2.py @@ -61,7 +61,8 @@ def __init__(self, obs_space: gym.spaces.Space, self.time_major = self.model_config.get("_time_major") # Basic view requirement for all models: Use the observation as input. self.inference_view_requirements = { - SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space), + SampleBatch.OBS: ViewRequirement( + data_rel_pos=0, space=self.obs_space), } # TODO: (sven): Get rid of `get_initial_state` once Trajectory diff --git a/rllib/models/tf/recurrent_net.py b/rllib/models/tf/recurrent_net.py index f939c7ae36a6..57ecb22f05e0 100644 --- a/rllib/models/tf/recurrent_net.py +++ b/rllib/models/tf/recurrent_net.py @@ -178,10 +178,10 @@ def __init__(self, obs_space: gym.spaces.Space, if model_config["lstm_use_prev_action"]: self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \ ViewRequirement(SampleBatch.ACTIONS, space=self.action_space, - shift=-1) + data_rel_pos=-1) if model_config["lstm_use_prev_reward"]: self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ - ViewRequirement(SampleBatch.REWARDS, shift=-1) + ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py index d558bf3dbf74..8e0d2263c7d2 100644 --- a/rllib/models/torch/recurrent_net.py +++ b/rllib/models/torch/recurrent_net.py @@ -159,10 +159,10 @@ def __init__(self, obs_space: gym.spaces.Space, if model_config["lstm_use_prev_action"]: self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \ ViewRequirement(SampleBatch.ACTIONS, space=self.action_space, - shift=-1) + data_rel_pos=-1) if model_config["lstm_use_prev_reward"]: self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \ - ViewRequirement(SampleBatch.REWARDS, shift=-1) + ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1) @override(RecurrentNetwork) def forward(self, input_dict: Dict[str, TensorType], diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 81a952018b08..a98c50f46497 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -564,13 +564,14 @@ def _get_default_view_requirements(self): SampleBatch.OBS: ViewRequirement(space=self.observation_space), SampleBatch.NEXT_OBS: ViewRequirement( data_col=SampleBatch.OBS, - shift=1, + data_rel_pos=1, space=self.observation_space), SampleBatch.ACTIONS: ViewRequirement(space=self.action_space), SampleBatch.REWARDS: ViewRequirement(), SampleBatch.DONES: ViewRequirement(), SampleBatch.INFOS: ViewRequirement(), SampleBatch.EPS_ID: 
ViewRequirement(), + SampleBatch.UNROLL_ID: ViewRequirement(), SampleBatch.AGENT_INDEX: ViewRequirement(), SampleBatch.UNROLL_ID: ViewRequirement(), "t": ViewRequirement(), @@ -617,7 +618,7 @@ def _initialize_loss_from_dummy_batch( batch_for_postproc.count = self._dummy_batch.count postprocessed_batch = self.postprocess_trajectory(batch_for_postproc) if state_outs: - B = 4 # For RNNs, have B=2, T=[depends on sample_batch_size] + B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size] # TODO: (sven) This hack will not work for attention net traj. # view setup. i = 0 @@ -657,7 +658,8 @@ def _initialize_loss_from_dummy_batch( # Tag those only needed for post-processing. for key in batch_for_postproc.accessed_keys: if key not in train_batch.accessed_keys and \ - key in self.view_requirements: + key in self.view_requirements and \ + key not in self.model.inference_view_requirements: self.view_requirements[key].used_for_training = False # Remove those not needed at all (leave those that are needed # by Sampler to properly execute sample collection). @@ -680,18 +682,6 @@ def _initialize_loss_from_dummy_batch( "postprocessing function.".format(key)) else: del self.view_requirements[key] - # Add those data_cols (again) that are missing and have - # dependencies by view_cols. - for key in list(self.view_requirements.keys()): - vr = self.view_requirements[key] - if vr.data_col is not None and \ - vr.data_col not in self.view_requirements: - used_for_training = \ - vr.data_col in train_batch.accessed_keys - self.view_requirements[vr.data_col] = \ - ViewRequirement( - space=vr.space, - used_for_training=used_for_training) def _get_dummy_batch_from_view_requirements( self, batch_size: int = 1) -> SampleBatch: @@ -727,7 +717,7 @@ def _update_model_inference_view_requirements_from_init_state(self): model.inference_view_requirements["state_in_{}".format(i)] = \ ViewRequirement( "state_out_{}".format(i), - shift=-1, + data_rel_pos=-1, space=Box(-1.0, 1.0, shape=state.shape)) model.inference_view_requirements["state_out_{}".format(i)] = \ ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape)) diff --git a/rllib/policy/view_requirement.py b/rllib/policy/view_requirement.py index 3264b759b532..8813c3aca288 100644 --- a/rllib/policy/view_requirement.py +++ b/rllib/policy/view_requirement.py @@ -29,7 +29,7 @@ class ViewRequirement: def __init__(self, data_col: Optional[str] = None, space: gym.Space = None, - shift: Union[int, List[int]] = 0, + data_rel_pos: Union[int, List[int]] = 0, used_for_training: bool = True): """Initializes a ViewRequirement object. @@ -40,13 +40,14 @@ def __init__(self, space (gym.Space): The gym Space used in case we need to pad data in inaccessible areas of the trajectory (t<0 or t>H). Default: Simple box space, e.g. rewards. - shift (Union[int, List[int]]): Single shift value of list of - shift values to use relative to the underlying `data_col`. + data_rel_pos (Union[int, List[int]]): Single relative position or + list of relative positions to use (relative to the underlying + `data_col`). Example: For a view column "prev_actions", you can set - `data_col="actions"` and `shift=-1`. + `data_col="actions"` and `data_rel_pos=-1`. Example: For a view column "obs" in an Atari framestacking fashion, you can set `data_col="obs"` and - `shift=[-3, -2, -1, 0]`. + `data_rel_pos=[-3, -2, -1, 0]`. used_for_training (bool): Whether the data will be used for training. If False, the column will not be copied into the final train batch. 
@@ -54,5 +55,5 @@ def __init__(self, self.data_col = data_col self.space = space or gym.spaces.Box( float("-inf"), float("inf"), shape=()) - self.shift = shift + self.data_rel_pos = data_rel_pos self.used_for_training = used_for_training
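
For reference, below is a minimal usage sketch (not part of the patch) of the renamed ViewRequirement keyword argument: `shift` becomes `data_rel_pos`, as in the LSTM prev-action/prev-reward requirements updated above. Module paths follow the files touched in this diff; the variable names, the "prev_4_obs" column, and the surrounding dict are hypothetical and purely illustrative.

# Illustrative sketch only; assumes the RLlib module layout shown in this diff.
import gym

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

action_space = gym.spaces.Discrete(2)  # hypothetical action space

# View-requirement dict for a hypothetical model. The kwarg is now
# `data_rel_pos` (formerly `shift`): the position to read, relative to the
# underlying `data_col`.
inference_view_requirements = {
    # Current observation (relative position 0).
    SampleBatch.OBS: ViewRequirement(data_rel_pos=0),
    # Previous action: read from the "actions" column, one step back.
    SampleBatch.PREV_ACTIONS: ViewRequirement(
        SampleBatch.ACTIONS, space=action_space, data_rel_pos=-1),
    # Previous reward, also shifted back by one step.
    SampleBatch.PREV_REWARDS: ViewRequirement(
        SampleBatch.REWARDS, data_rel_pos=-1),
    # Framestacking-style view: the last four observations.
    "prev_4_obs": ViewRequirement(
        SampleBatch.OBS, data_rel_pos=[-3, -2, -1, 0]),
}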