[RLlib] Attention Net prep PR #2: Smaller cleanups. #12449

Merged: 11 commits, Dec 1, 2020
12 changes: 8 additions & 4 deletions rllib/evaluation/collectors/sample_collector.py
@@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta):

@abstractmethod
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
policy_id: PolicyID, init_obs: TensorType) -> None:
policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
"""Adds an initial obs (after reset) to this collector.

Since the very first observation in an environment is collected w/o
@@ -48,6 +49,8 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
values for.
env_id (EnvID): The environment index (in a vectorized setup).
policy_id (PolicyID): Unique id for policy controlling the agent.
t (int): The time step (episode length - 1). The initial obs has
ts=-1 (!), then an action/reward/next-obs at t=0, etc.
init_obs (TensorType): Initial observation (after env.reset()).

Examples:
@@ -172,9 +175,10 @@ def postprocess_episode(self,
MultiAgentBatch. Used for batch_mode=`complete_episodes`.

Returns:
Any: An ID that can be used in `build_multi_agent_batch` to
retrieve the samples that have been postprocessed as a
ready-built MultiAgentBatch.
Optional[MultiAgentBatch]: If `build` is True, the
SampleBatch or MultiAgentBatch built from `episode` (either
just from that episode or from the `_PolicyCollectorGroup`
in the `episode.batch_builder` property).
"""
raise NotImplementedError

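To make the `t` convention introduced above concrete, here is a small, self-contained sketch (plain Python, not part of the diff) of the intended bookkeeping: the initial observation is stored at ts=-1, and each env step then produces an action/reward/next-obs row at t=0, 1, 2, and so on.

```python
# Conceptual sketch only: t = episode length - 1 at the time of recording.
episode_length = 0                  # right after env.reset()
t_init = episode_length - 1         # -1: where the initial obs is stored

timeline = [("obs_0", t_init)]      # initial obs lives at ts=-1
for step in range(3):               # three env steps
    episode_length += 1
    t = episode_length - 1          # 0, 1, 2, ...
    timeline.append((f"action/reward/obs_{step + 1}", t))

print(timeline)
# [('obs_0', -1), ('action/reward/obs_1', 0),
#  ('action/reward/obs_2', 1), ('action/reward/obs_3', 2)]
```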
31 changes: 21 additions & 10 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -52,17 +52,19 @@ def __init__(self, shift_before: int = 0):
# each time a (non-initial!) observation is added.
self.count = 0

def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
env_id: EnvID, init_obs: TensorType,
def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
Review comment (Contributor Author):
  • agent_id vs agent_idx was a bug
  • added timestep

env_id: EnvID, t: int, init_obs: TensorType,
view_requirements: Dict[str, ViewRequirement]) -> None:
"""Adds an initial observation (after reset) to the Agent's trajectory.

Args:
episode_id (EpisodeID): Unique ID for the episode we are adding the
initial observation for.
agent_id (AgentID): Unique ID for the agent we are adding the
initial observation for.
agent_index (int): Unique int index (starting from 0) for the agent
within its episode.
env_id (EnvID): The environment index (in a vectorized setup).
t (int): The time step (episode length - 1). The initial obs has
ts=-1 (!), then an action/reward/next-obs at t=0, etc.
init_obs (TensorType): The initial observation tensor (after
`env.reset()`).
view_requirements (Dict[str, ViewRequirements])
@@ -72,10 +74,15 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.EPS_ID: episode_id,
SampleBatch.AGENT_INDEX: agent_id,
SampleBatch.AGENT_INDEX: agent_index,
"env_id": env_id,
"t": t,
})
self.buffers[SampleBatch.OBS].append(init_obs)
self.buffers[SampleBatch.EPS_ID].append(episode_id)
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
self.buffers["env_id"].append(env_id)
self.buffers["t"].append(t)

def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
None:
@@ -133,7 +140,7 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
continue
# OBS are already shifted by -1 (the initial obs starts one ts
# before all other data columns).
shift = view_req.shift - \
shift = view_req.data_rel_pos - \
Review comment (Contributor Author):
Renamed this b/c this will support (in the upcoming PRs) not just a single shift (int), but also:
  • a list of ints (include not just one ts in this view, but several)
  • a range string, e.g. "-50:-1" (will be used by attention nets and Atari framestacking).

(1 if data_col == SampleBatch.OBS else 0)
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
@@ -187,7 +194,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
for col, data in single_row.items():
if col in self.buffers:
continue
shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
shift = self.shift_before - (1 if col in [
SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
"env_id", "t"
] else 0)
# Python primitive or dict (e.g. INFOs).
if isinstance(data, (int, float, bool, str, dict)):
self.buffers[col] = [0 for _ in range(shift)]
@@ -360,7 +370,7 @@ def episode_step(self, episode_id: EpisodeID) -> None:

@override(_SampleCollector)
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
env_id: EnvID, policy_id: PolicyID,
env_id: EnvID, policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
# Make sure our mappings are up to date.
agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
self.agent_collectors[agent_key] = _AgentCollector()
self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id,
agent_id=agent_id,
agent_index=episode._agent_index(agent_id),
env_id=env_id,
t=t,
init_obs=init_obs,
view_requirements=view_reqs)

@@ -429,7 +440,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
time_indices = \
view_req.shift - (
view_req.data_rel_pos - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX] else 0)
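The widened shift list in `_build_buffers()` and the `data_rel_pos` arithmetic in `build()` both handle the same off-by-one: obs-like columns carry one extra leading entry (the initial obs at ts=-1). A conceptual, self-contained sketch of that alignment (plain Python lists, not RLlib code; the column values are made up):

```python
# The obs buffer holds one extra leading entry (the initial obs at ts=-1),
# so obs-like columns are read one slot earlier than actions/rewards when
# a train batch is built.
obs     = ["o_-1", "o_0", "o_1", "o_2"]   # init obs + one obs per env step
actions = ["a_0", "a_1", "a_2"]
rewards = [1.0, 0.5, 2.0]

batch = {
    "obs":     obs[:-1],   # ["o_-1", "o_0", "o_1"]  (data_rel_pos=0, shifted by -1)
    "actions": actions,
    "rewards": rewards,
    "new_obs": obs[1:],    # ["o_0", "o_1", "o_2"]   (data_rel_pos=+1)
}
# Row t now pairs the obs seen at t with the action taken and reward received at t.
assert len(batch["obs"]) == len(batch["actions"]) == len(batch["new_obs"]) == 3
```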
8 changes: 5 additions & 3 deletions rllib/evaluation/rollout_worker.py
@@ -272,9 +272,11 @@ def __init__(
output_creator (Callable[[IOContext], OutputWriter]): Function that
returns an OutputWriter object for saving generated
experiences.
remote_worker_envs (bool): If using num_envs > 1, whether to create
those new envs in remote processes instead of in the current
process. This adds overheads, but can make sense if your envs
remote_worker_envs (bool): If using num_envs_per_worker > 1,
whether to create those new envs in remote processes instead of
in the current process. This adds overheads, but can make sense
if your envs are expensive to step/reset (e.g., for StarCraft).
Use this cautiously; the overheads are significant!
Review comment (Contributor): 👍

remote_env_batch_wait_ms (float): Timeout that remote workers
are waiting when polling environments. 0 (continue when at
least one env is ready) is a reasonable default, but optimal
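For reference, a minimal sketch of how the three settings documented above fit together in a trainer config; the keys come from the docstring, while the PPO/CartPole choice and the concrete values are illustrative only:

```python
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(
    env="CartPole-v0",
    config={
        "num_workers": 2,
        "num_envs_per_worker": 4,       # vectorize 4 envs per rollout worker
        "remote_worker_envs": True,     # step each of those envs in its own
                                        # remote process (only worth the
                                        # overhead for slow envs!)
        "remote_env_batch_wait_ms": 0,  # continue as soon as any env is ready
    })
print(trainer.train()["episode_reward_mean"])
```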
6 changes: 4 additions & 2 deletions rllib/evaluation/sampler.py
@@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api(
# Record transition info if applicable.
if last_observation is None:
_sample_collector.add_init_obs(episode, agent_id, env_id,
policy_id, filtered_obs)
policy_id, episode.length - 1,
filtered_obs)
else:
# Add actions, rewards, next-obs to collectors.
values_dict = {
@@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api(

# Add initial obs to buffer.
_sample_collector.add_init_obs(
new_episode, agent_id, env_id, policy_id, filtered_obs)
new_episode, agent_id, env_id, policy_id,
new_episode.length - 1, filtered_obs)

item = PolicyEvalData(
env_id, agent_id, filtered_obs,
8 changes: 4 additions & 4 deletions rllib/evaluation/tests/test_trajectory_view_api.py
@@ -59,7 +59,7 @@ def test_traj_view_normal_case(self):
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
assert view_req_policy[key].data_rel_pos == 1
rollout_worker = trainer.workers.local_worker()
sample_batch = rollout_worker.sample()
expected_count = \
@@ -99,18 +99,18 @@ def test_traj_view_lstm_prev_actions_and_rewards(self):

if key == SampleBatch.PREV_ACTIONS:
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
assert view_req_policy[key].shift == -1
assert view_req_policy[key].data_rel_pos == -1
elif key == SampleBatch.PREV_REWARDS:
assert view_req_policy[key].data_col == SampleBatch.REWARDS
assert view_req_policy[key].shift == -1
assert view_req_policy[key].data_rel_pos == -1
elif key not in [
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
]:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
assert view_req_policy[key].data_rel_pos == 1
trainer.stop()

def test_traj_view_simple_performance(self):
10 changes: 6 additions & 4 deletions rllib/examples/policy/episode_env_aware_policy.py
@@ -28,22 +28,24 @@ class _fake_model:
"t": ViewRequirement(),
SampleBatch.OBS: ViewRequirement(),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
SampleBatch.REWARDS, data_rel_pos=-1),
}
for i in range(2):
self.model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i), shift=-1, space=self.state_space)
"state_out_{}".format(i),
data_rel_pos=-1,
space=self.state_space)
self.model.inference_view_requirements[
"state_out_{}".format(i)] = \
ViewRequirement(space=self.state_space)

self.view_requirements = dict(
**{
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, shift=1),
SampleBatch.OBS, data_rel_pos=1),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
2 changes: 1 addition & 1 deletion rllib/examples/policy/rock_paper_scissors_dummies.py
@@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs):
self.view_requirements.update({
"state_in_0": ViewRequirement(
"state_out_0",
shift=-1,
data_rel_pos=-1,
space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
})

3 changes: 2 additions & 1 deletion rllib/models/modelv2.py
@@ -61,7 +61,8 @@ def __init__(self, obs_space: gym.spaces.Space,
self.time_major = self.model_config.get("_time_major")
# Basic view requirement for all models: Use the observation as input.
self.inference_view_requirements = {
SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
SampleBatch.OBS: ViewRequirement(
data_rel_pos=0, space=self.obs_space),
}

# TODO: (sven): Get rid of `get_initial_state` once Trajectory
4 changes: 2 additions & 2 deletions rllib/models/tf/recurrent_net.py
@@ -178,10 +178,10 @@ def __init__(self, obs_space: gym.spaces.Space,
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
data_rel_pos=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],
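A short sketch of how the view requirements added above are switched on from the model config; the keys (`use_lstm`, `lstm_use_prev_action`, `lstm_use_prev_reward`) are the ones read by the recurrent wrappers in this diff, while the PPO/CartPole setup and the cell size are illustrative:

```python
import ray
from ray.rllib.agents.ppo import PPOTrainer

config = {
    "framework": "tf",
    "model": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "lstm_use_prev_action": True,   # -> PREV_ACTIONS view requirement
                                        #    (ACTIONS with data_rel_pos=-1)
        "lstm_use_prev_reward": True,   # -> PREV_REWARDS view requirement
                                        #    (REWARDS with data_rel_pos=-1)
    },
}
ray.init()
trainer = PPOTrainer(env="CartPole-v0", config=config)
```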
4 changes: 2 additions & 2 deletions rllib/models/torch/recurrent_net.py
@@ -159,10 +159,10 @@ def __init__(self, obs_space: gym.spaces.Space,
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
data_rel_pos=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],
22 changes: 6 additions & 16 deletions rllib/policy/policy.py
@@ -564,13 +564,14 @@ def _get_default_view_requirements(self):
SampleBatch.OBS: ViewRequirement(space=self.observation_space),
SampleBatch.NEXT_OBS: ViewRequirement(
data_col=SampleBatch.OBS,
shift=1,
data_rel_pos=1,
space=self.observation_space),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
SampleBatch.INFOS: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
"t": ViewRequirement(),
@@ -617,7 +618,7 @@ def _initialize_loss_from_dummy_batch(
batch_for_postproc.count = self._dummy_batch.count
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
if state_outs:
B = 4 # For RNNs, have B=2, T=[depends on sample_batch_size]
B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size]
# TODO: (sven) This hack will not work for attention net traj.
# view setup.
i = 0
@@ -657,7 +658,8 @@ def _initialize_loss_from_dummy_batch(
# Tag those only needed for post-processing.
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys and \
key in self.view_requirements:
key in self.view_requirements and \
key not in self.model.inference_view_requirements:
self.view_requirements[key].used_for_training = False
# Remove those not needed at all (leave those that are needed
# by Sampler to properly execute sample collection).
@@ -680,18 +682,6 @@
"postprocessing function.".format(key))
else:
del self.view_requirements[key]
# Add those data_cols (again) that are missing and have
# dependencies by view_cols.
for key in list(self.view_requirements.keys()):
vr = self.view_requirements[key]
if vr.data_col is not None and \
vr.data_col not in self.view_requirements:
used_for_training = \
vr.data_col in train_batch.accessed_keys
self.view_requirements[vr.data_col] = \
ViewRequirement(
space=vr.space,
used_for_training=used_for_training)

def _get_dummy_batch_from_view_requirements(
self, batch_size: int = 1) -> SampleBatch:
@@ -727,7 +717,7 @@ def _update_model_inference_view_requirements_from_init_state(self):
model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift=-1,
data_rel_pos=-1,
space=Box(-1.0, 1.0, shape=state.shape))
model.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
13 changes: 7 additions & 6 deletions rllib/policy/view_requirement.py
@@ -29,7 +29,7 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
shift: Union[int, List[int]] = 0,
data_rel_pos: Union[int, List[int]] = 0,
Review comment (Contributor):
Why not keep it as shift? It seems more intuitive.

Review comment (Contributor Author):
I liked shift, too. The problem is that there will also be an abs_pos soon (see the attention net PRs), so I wanted to distinguish between these two concepts.

used_for_training: bool = True):
"""Initializes a ViewRequirement object.

@@ -40,19 +40,20 @@ def __init__(self,
space (gym.Space): The gym Space used in case we need to pad data
in inaccessible areas of the trajectory (t<0 or t>H).
Default: Simple box space, e.g. rewards.
shift (Union[int, List[int]]): Single shift value of list of
shift values to use relative to the underlying `data_col`.
data_rel_pos (Union[int, str, List[int]]): Single shift value or
list of relative positions to use (relative to the underlying
`data_col`).
Example: For a view column "prev_actions", you can set
`data_col="actions"` and `shift=-1`.
`data_col="actions"` and `data_rel_pos=-1`.
Example: For a view column "obs" in an Atari framestacking
fashion, you can set `data_col="obs"` and
`shift=[-3, -2, -1, 0]`.
`data_rel_pos=[-3, -2, -1, 0]`.
used_for_training (bool): Whether the data will be used for
training. If False, the column will not be copied into the
final train batch.
"""
self.data_col = data_col
self.space = space or gym.spaces.Box(
float("-inf"), float("inf"), shape=())
self.shift = shift
self.data_rel_pos = data_rel_pos
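
Finally, a sketch of the two docstring examples above as actual `ViewRequirement` constructions. Module paths follow the file locations in this PR (`ray.rllib.policy.view_requirement`, `ray.rllib.policy.sample_batch`), the `Discrete(2)` action space is just a placeholder, and list-valued `data_rel_pos` is only being enabled in the follow-up attention-net PRs:

```python
import gym
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

view_requirements = {
    # "prev_actions" = the action taken one timestep earlier:
    SampleBatch.PREV_ACTIONS: ViewRequirement(
        data_col=SampleBatch.ACTIONS,
        space=gym.spaces.Discrete(2),
        data_rel_pos=-1),
    # Atari-style framestacking: view the last 4 observations per row
    # (list-valued data_rel_pos; support lands in upcoming PRs):
    SampleBatch.OBS: ViewRequirement(
        data_col=SampleBatch.OBS,
        data_rel_pos=[-3, -2, -1, 0]),
}
print(view_requirements[SampleBatch.PREV_ACTIONS].data_rel_pos)  # -1
```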
self.used_for_training = used_for_training