[RLlib] Attention Net prep PR #2: Smaller cleanups. (#12449)
sven1977 authored Dec 1, 2020
1 parent e72147d commit 3ad9365
Showing 12 changed files with 68 additions and 55 deletions.
12 changes: 8 additions & 4 deletions rllib/evaluation/collectors/sample_collector.py
@@ -31,7 +31,8 @@ class _SampleCollector(metaclass=ABCMeta):

@abstractmethod
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
policy_id: PolicyID, init_obs: TensorType) -> None:
policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
"""Adds an initial obs (after reset) to this collector.
Since the very first observation in an environment is collected w/o
@@ -48,6 +49,8 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
values for.
env_id (EnvID): The environment index (in a vectorized setup).
policy_id (PolicyID): Unique id for policy controlling the agent.
t (int): The time step (episode length - 1). The initial obs has
ts=-1(!); the first action/reward/next-obs then land at t=0, etc.
init_obs (TensorType): Initial observation (after env.reset()).
Examples:
@@ -172,9 +175,10 @@ def postprocess_episode(self,
MultiAgentBatch. Used for batch_mode=`complete_episodes`.
Returns:
Any: An ID that can be used in `build_multi_agent_batch` to
retrieve the samples that have been postprocessed as a
ready-built MultiAgentBatch.
Optional[MultiAgentBatch]: If `build` is True, the
SampleBatch or MultiAgentBatch built from `episode` (either
just from that episode or from the `_PolicyCollectorGroup`
in the `episode.batch_builder` property).
"""
raise NotImplementedError

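The new `t` argument encodes the ts=-1 convention spelled out in the docstring above. A minimal sketch (plain Python, nothing imported from RLlib; only an `episode.length` counter is assumed) of how a caller derives the value:

```python
# Hypothetical illustration of the ts=-1 convention for the new `t` argument.
# Right after `env.reset()` the episode length is still 0, so the initial obs
# is recorded at t = episode.length - 1 = -1; the first action/reward/next-obs
# then land at t = 0.
class FakeEpisode:
    """Stand-in for MultiAgentEpisode; only the length counter matters here."""

    def __init__(self) -> None:
        self.length = 0  # no env steps taken yet


episode = FakeEpisode()
t_init = episode.length - 1    # -1: where the initial observation goes
assert t_init == -1

episode.length += 1            # first env step happened
t_step0 = episode.length - 1   # 0: first action/reward/next-obs
assert t_step0 == 0
```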
31 changes: 21 additions & 10 deletions rllib/evaluation/collectors/simple_list_collector.py
@@ -52,17 +52,19 @@ def __init__(self, shift_before: int = 0):
# each time a (non-initial!) observation is added.
self.count = 0

def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
env_id: EnvID, init_obs: TensorType,
def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
env_id: EnvID, t: int, init_obs: TensorType,
view_requirements: Dict[str, ViewRequirement]) -> None:
"""Adds an initial observation (after reset) to the Agent's trajectory.
Args:
episode_id (EpisodeID): Unique ID for the episode we are adding the
initial observation for.
agent_id (AgentID): Unique ID for the agent we are adding the
initial observation for.
agent_index (int): Unique int index (starting from 0) for the agent
within its episode.
env_id (EnvID): The environment index (in a vectorized setup).
t (int): The time step (episode length - 1). The initial obs has
ts=-1(!); the first action/reward/next-obs then land at t=0, etc.
init_obs (TensorType): The initial observation tensor (after
`env.reset()`).
view_requirements (Dict[str, ViewRequirement])
@@ -72,10 +74,15 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
single_row={
SampleBatch.OBS: init_obs,
SampleBatch.EPS_ID: episode_id,
SampleBatch.AGENT_INDEX: agent_id,
SampleBatch.AGENT_INDEX: agent_index,
"env_id": env_id,
"t": t,
})
self.buffers[SampleBatch.OBS].append(init_obs)
self.buffers[SampleBatch.EPS_ID].append(episode_id)
self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
self.buffers["env_id"].append(env_id)
self.buffers["t"].append(t)

def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
None:
@@ -133,7 +140,7 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
continue
# OBS are already shifted by -1 (the initial obs starts one ts
# before all other data columns).
shift = view_req.shift - \
shift = view_req.data_rel_pos - \
(1 if data_col == SampleBatch.OBS else 0)
if data_col not in np_data:
np_data[data_col] = to_float_np_array(self.buffers[data_col])
@@ -187,7 +194,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
for col, data in single_row.items():
if col in self.buffers:
continue
shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
shift = self.shift_before - (1 if col in [
SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
"env_id", "t"
] else 0)
# Python primitive or dict (e.g. INFOs).
if isinstance(data, (int, float, bool, str, dict)):
self.buffers[col] = [0 for _ in range(shift)]
@@ -360,7 +370,7 @@ def episode_step(self, episode_id: EpisodeID) -> None:

@override(_SampleCollector)
def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
env_id: EnvID, policy_id: PolicyID,
env_id: EnvID, policy_id: PolicyID, t: int,
init_obs: TensorType) -> None:
# Make sure our mappings are up to date.
agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
self.agent_collectors[agent_key] = _AgentCollector()
self.agent_collectors[agent_key].add_init_obs(
episode_id=episode.episode_id,
agent_id=agent_id,
agent_index=episode._agent_index(agent_id),
env_id=env_id,
t=t,
init_obs=init_obs,
view_requirements=view_reqs)

@@ -429,7 +440,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
# Create the batch of data from the different buffers.
data_col = view_req.data_col or view_col
time_indices = \
view_req.shift - (
view_req.data_rel_pos - (
1 if data_col in [SampleBatch.OBS, "t", "env_id",
SampleBatch.EPS_ID,
SampleBatch.AGENT_INDEX] else 0)
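The `- 1` adjustment for OBS above exists because the OBS buffer is pre-filled with the initial observation at ts=-1 and therefore runs one element ahead of the action/reward columns. A small sketch (plain Python constants, not RLlib code) of that arithmetic:

```python
# Hypothetical mirror of the offset arithmetic in `build()` above: the OBS
# buffer already contains the ts=-1 entry, so it runs one element ahead of
# ACTIONS/REWARDS and its relative position is reduced by one to re-align
# the views.
OBS, REWARDS = "obs", "rewards"


def effective_shift(data_col: str, data_rel_pos: int) -> int:
    # Same shape as: shift = view_req.data_rel_pos - (1 if data_col == OBS else 0)
    return data_rel_pos - (1 if data_col == OBS else 0)


assert effective_shift(OBS, 1) == 0        # NEXT_OBS -> current buffer index
assert effective_shift(OBS, 0) == -1       # OBS itself -> one element back
assert effective_shift(REWARDS, -1) == -1  # PREV_REWARDS -> plain -1 shift
```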
8 changes: 5 additions & 3 deletions rllib/evaluation/rollout_worker.py
@@ -272,9 +272,11 @@ def __init__(
output_creator (Callable[[IOContext], OutputWriter]): Function that
returns an OutputWriter object for saving generated
experiences.
remote_worker_envs (bool): If using num_envs > 1, whether to create
those new envs in remote processes instead of in the current
process. This adds overheads, but can make sense if your envs
remote_worker_envs (bool): If using num_envs_per_worker > 1,
whether to create those new envs in remote processes instead of
in the current process. This adds overheads, but can make sense
if your envs are expensive to step/reset (e.g., for StarCraft).
Use this cautiously; the overheads are significant!
remote_env_batch_wait_ms (float): Timeout that remote workers
are waiting when polling environments. 0 (continue when at
least one env is ready) is a reasonable default, but optimal
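A hedged config sketch (standard RLlib trainer-config keys of this era; not part of this diff) showing how `remote_worker_envs` pairs with `num_envs_per_worker`:

```python
# Hypothetical trainer config sketch: vectorize 4 envs per rollout worker and
# step each of them in its own remote process. Only worthwhile when
# env.step()/env.reset() is expensive enough to amortize the IPC overhead.
config = {
    "num_workers": 2,
    "num_envs_per_worker": 4,       # vectorized envs per rollout worker
    "remote_worker_envs": True,     # create those envs as remote actors
    "remote_env_batch_wait_ms": 0,  # continue as soon as one env is ready
}
```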
6 changes: 4 additions & 2 deletions rllib/evaluation/sampler.py
@@ -1040,7 +1040,8 @@ def _process_observations_w_trajectory_view_api(
# Record transition info if applicable.
if last_observation is None:
_sample_collector.add_init_obs(episode, agent_id, env_id,
policy_id, filtered_obs)
policy_id, episode.length - 1,
filtered_obs)
else:
# Add actions, rewards, next-obs to collectors.
values_dict = {
@@ -1158,7 +1159,8 @@ def _process_observations_w_trajectory_view_api(

# Add initial obs to buffer.
_sample_collector.add_init_obs(
new_episode, agent_id, env_id, policy_id, filtered_obs)
new_episode, agent_id, env_id, policy_id,
new_episode.length - 1, filtered_obs)

item = PolicyEvalData(
env_id, agent_id, filtered_obs,
8 changes: 4 additions & 4 deletions rllib/evaluation/tests/test_trajectory_view_api.py
@@ -59,7 +59,7 @@ def test_traj_view_normal_case(self):
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
assert view_req_policy[key].data_rel_pos == 1
rollout_worker = trainer.workers.local_worker()
sample_batch = rollout_worker.sample()
expected_count = \
@@ -99,18 +99,18 @@ def test_traj_view_lstm_prev_actions_and_rewards(self):

if key == SampleBatch.PREV_ACTIONS:
assert view_req_policy[key].data_col == SampleBatch.ACTIONS
assert view_req_policy[key].shift == -1
assert view_req_policy[key].data_rel_pos == -1
elif key == SampleBatch.PREV_REWARDS:
assert view_req_policy[key].data_col == SampleBatch.REWARDS
assert view_req_policy[key].shift == -1
assert view_req_policy[key].data_rel_pos == -1
elif key not in [
SampleBatch.NEXT_OBS, SampleBatch.PREV_ACTIONS,
SampleBatch.PREV_REWARDS
]:
assert view_req_policy[key].data_col is None
else:
assert view_req_policy[key].data_col == SampleBatch.OBS
assert view_req_policy[key].shift == 1
assert view_req_policy[key].data_rel_pos == 1
trainer.stop()

def test_traj_view_simple_performance(self):
10 changes: 6 additions & 4 deletions rllib/examples/policy/episode_env_aware_policy.py
@@ -28,22 +28,24 @@ class _fake_model:
"t": ViewRequirement(),
SampleBatch.OBS: ViewRequirement(),
SampleBatch.PREV_ACTIONS: ViewRequirement(
SampleBatch.ACTIONS, space=self.action_space, shift=-1),
SampleBatch.ACTIONS, space=self.action_space, data_rel_pos=-1),
SampleBatch.PREV_REWARDS: ViewRequirement(
SampleBatch.REWARDS, shift=-1),
SampleBatch.REWARDS, data_rel_pos=-1),
}
for i in range(2):
self.model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i), shift=-1, space=self.state_space)
"state_out_{}".format(i),
data_rel_pos=-1,
space=self.state_space)
self.model.inference_view_requirements[
"state_out_{}".format(i)] = \
ViewRequirement(space=self.state_space)

self.view_requirements = dict(
**{
SampleBatch.NEXT_OBS: ViewRequirement(
SampleBatch.OBS, shift=1),
SampleBatch.OBS, data_rel_pos=1),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
2 changes: 1 addition & 1 deletion rllib/examples/policy/rock_paper_scissors_dummies.py
@@ -16,7 +16,7 @@ def __init__(self, *args, **kwargs):
self.view_requirements.update({
"state_in_0": ViewRequirement(
"state_out_0",
shift=-1,
data_rel_pos=-1,
space=gym.spaces.Box(0, 100, shape=(), dtype=np.int32))
})

3 changes: 2 additions & 1 deletion rllib/models/modelv2.py
@@ -61,7 +61,8 @@ def __init__(self, obs_space: gym.spaces.Space,
self.time_major = self.model_config.get("_time_major")
# Basic view requirement for all models: Use the observation as input.
self.inference_view_requirements = {
SampleBatch.OBS: ViewRequirement(shift=0, space=self.obs_space),
SampleBatch.OBS: ViewRequirement(
data_rel_pos=0, space=self.obs_space),
}

# TODO: (sven): Get rid of `get_initial_state` once Trajectory
4 changes: 2 additions & 2 deletions rllib/models/tf/recurrent_net.py
@@ -178,10 +178,10 @@ def __init__(self, obs_space: gym.spaces.Space,
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
data_rel_pos=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],
4 changes: 2 additions & 2 deletions rllib/models/torch/recurrent_net.py
@@ -159,10 +159,10 @@ def __init__(self, obs_space: gym.spaces.Space,
if model_config["lstm_use_prev_action"]:
self.inference_view_requirements[SampleBatch.PREV_ACTIONS] = \
ViewRequirement(SampleBatch.ACTIONS, space=self.action_space,
shift=-1)
data_rel_pos=-1)
if model_config["lstm_use_prev_reward"]:
self.inference_view_requirements[SampleBatch.PREV_REWARDS] = \
ViewRequirement(SampleBatch.REWARDS, shift=-1)
ViewRequirement(SampleBatch.REWARDS, data_rel_pos=-1)

@override(RecurrentNetwork)
def forward(self, input_dict: Dict[str, TensorType],
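For both the TF and Torch recurrent nets above, the `data_rel_pos=-1` requirements are only registered when the corresponding model-config flags are set. A hedged sketch (standard model-config keys; not part of this diff) of the config that activates them:

```python
# Hypothetical model config sketch: with these flags on, the recurrent model
# registers PREV_ACTIONS / PREV_REWARDS view requirements that read ACTIONS
# and REWARDS at data_rel_pos=-1.
config = {
    "model": {
        "use_lstm": True,
        "lstm_cell_size": 64,
        "lstm_use_prev_action": True,   # feed a_{t-1} into the LSTM
        "lstm_use_prev_reward": True,   # feed r_{t-1} into the LSTM
    },
}
```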
22 changes: 6 additions & 16 deletions rllib/policy/policy.py
@@ -564,13 +564,14 @@ def _get_default_view_requirements(self):
SampleBatch.OBS: ViewRequirement(space=self.observation_space),
SampleBatch.NEXT_OBS: ViewRequirement(
data_col=SampleBatch.OBS,
shift=1,
data_rel_pos=1,
space=self.observation_space),
SampleBatch.ACTIONS: ViewRequirement(space=self.action_space),
SampleBatch.REWARDS: ViewRequirement(),
SampleBatch.DONES: ViewRequirement(),
SampleBatch.INFOS: ViewRequirement(),
SampleBatch.EPS_ID: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
SampleBatch.AGENT_INDEX: ViewRequirement(),
SampleBatch.UNROLL_ID: ViewRequirement(),
"t": ViewRequirement(),
@@ -617,7 +618,7 @@ def _initialize_loss_from_dummy_batch(
batch_for_postproc.count = self._dummy_batch.count
postprocessed_batch = self.postprocess_trajectory(batch_for_postproc)
if state_outs:
B = 4 # For RNNs, have B=2, T=[depends on sample_batch_size]
B = 4 # For RNNs, have B=4, T=[depends on sample_batch_size]
# TODO: (sven) This hack will not work for attention net traj.
# view setup.
i = 0
@@ -657,7 +658,8 @@ def _initialize_loss_from_dummy_batch(
# Tag those only needed for post-processing.
for key in batch_for_postproc.accessed_keys:
if key not in train_batch.accessed_keys and \
key in self.view_requirements:
key in self.view_requirements and \
key not in self.model.inference_view_requirements:
self.view_requirements[key].used_for_training = False
# Remove those not needed at all (leave those that are needed
# by Sampler to properly execute sample collection).
@@ -680,18 +682,6 @@ def _initialize_loss_from_dummy_batch(
"postprocessing function.".format(key))
else:
del self.view_requirements[key]
# Add those data_cols (again) that are missing and have
# dependencies by view_cols.
for key in list(self.view_requirements.keys()):
vr = self.view_requirements[key]
if vr.data_col is not None and \
vr.data_col not in self.view_requirements:
used_for_training = \
vr.data_col in train_batch.accessed_keys
self.view_requirements[vr.data_col] = \
ViewRequirement(
space=vr.space,
used_for_training=used_for_training)

def _get_dummy_batch_from_view_requirements(
self, batch_size: int = 1) -> SampleBatch:
@@ -727,7 +717,7 @@ def _update_model_inference_view_requirements_from_init_state(self):
model.inference_view_requirements["state_in_{}".format(i)] = \
ViewRequirement(
"state_out_{}".format(i),
shift=-1,
data_rel_pos=-1,
space=Box(-1.0, 1.0, shape=state.shape))
model.inference_view_requirements["state_out_{}".format(i)] = \
ViewRequirement(space=Box(-1.0, 1.0, shape=state.shape))
13 changes: 7 additions & 6 deletions rllib/policy/view_requirement.py
@@ -29,7 +29,7 @@ class ViewRequirement:
def __init__(self,
data_col: Optional[str] = None,
space: gym.Space = None,
shift: Union[int, List[int]] = 0,
data_rel_pos: Union[int, List[int]] = 0,
used_for_training: bool = True):
"""Initializes a ViewRequirement object.
@@ -40,19 +40,20 @@ def __init__(self,
space (gym.Space): The gym Space used in case we need to pad data
in inaccessible areas of the trajectory (t<0 or t>H).
Default: Simple box space, e.g. rewards.
shift (Union[int, List[int]]): Single shift value of list of
shift values to use relative to the underlying `data_col`.
data_rel_pos (Union[int, List[int]]): Single shift value or
list of relative positions to use (relative to the underlying
`data_col`).
Example: For a view column "prev_actions", you can set
`data_col="actions"` and `shift=-1`.
`data_col="actions"` and `data_rel_pos=-1`.
Example: For a view column "obs" in an Atari framestacking
fashion, you can set `data_col="obs"` and
`shift=[-3, -2, -1, 0]`.
`data_rel_pos=[-3, -2, -1, 0]`.
used_for_training (bool): Whether the data will be used for
training. If False, the column will not be copied into the
final train batch.
"""
self.data_col = data_col
self.space = space or gym.spaces.Box(
float("-inf"), float("inf"), shape=())
self.shift = shift
self.data_rel_pos = data_rel_pos
self.used_for_training = used_for_training
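
A short usage sketch of the renamed kwarg (import paths assumed from this repo's layout, Ray ~1.1-era API), covering the two examples in the docstring above:

```python
# Hypothetical usage of the renamed `data_rel_pos` kwarg (formerly `shift`),
# mirroring the two docstring examples. Import paths are assumptions.
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

view_requirements = {
    # "prev_actions": the ACTIONS column, shifted one step back.
    SampleBatch.PREV_ACTIONS: ViewRequirement(
        data_col=SampleBatch.ACTIONS, data_rel_pos=-1),
    # Atari-style framestacking: the last 4 observations at each timestep.
    "framestacked_obs": ViewRequirement(
        data_col=SampleBatch.OBS, data_rel_pos=[-3, -2, -1, 0]),
}
```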
