[RLlib] Attention Net prep PR #2: Smaller cleanups. #12449
@@ -52,17 +52,19 @@ def __init__(self, shift_before: int = 0):
         # each time a (non-initial!) observation is added.
         self.count = 0

-    def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
-                     env_id: EnvID, init_obs: TensorType,
+    def add_init_obs(self, episode_id: EpisodeID, agent_index: int,
+                     env_id: EnvID, t: int, init_obs: TensorType,
                      view_requirements: Dict[str, ViewRequirement]) -> None:
         """Adds an initial observation (after reset) to the Agent's trajectory.

         Args:
             episode_id (EpisodeID): Unique ID for the episode we are adding the
                 initial observation for.
-            agent_id (AgentID): Unique ID for the agent we are adding the
-                initial observation for.
+            agent_index (int): Unique int index (starting from 0) for the agent
+                within its episode.
             env_id (EnvID): The environment index (in a vectorized setup).
+            t (int): The time step (episode length - 1). The initial obs has
+                ts=-1(!), then an action/reward/next-obs at t=0, etc..
             init_obs (TensorType): The initial observation tensor (after
                 `env.reset()`).
             view_requirements (Dict[str, ViewRequirements])
@@ -72,10 +74,15 @@ def add_init_obs(self, episode_id: EpisodeID, agent_id: AgentID,
             single_row={
                 SampleBatch.OBS: init_obs,
                 SampleBatch.EPS_ID: episode_id,
-                SampleBatch.AGENT_INDEX: agent_id,
+                SampleBatch.AGENT_INDEX: agent_index,
                 "env_id": env_id,
+                "t": t,
             })
         self.buffers[SampleBatch.OBS].append(init_obs)
+        self.buffers[SampleBatch.EPS_ID].append(episode_id)
+        self.buffers[SampleBatch.AGENT_INDEX].append(agent_index)
+        self.buffers["env_id"].append(env_id)
+        self.buffers["t"].append(t)

     def add_action_reward_next_obs(self, values: Dict[str, TensorType]) -> \
             None:
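For context, a minimal sketch (not part of the diff) of how a caller would feed the new signature; `collector`, `eps_id`, `obs`, and `policy` are assumed to exist, and the key point is the t=-1 convention for the reset observation:

# Hypothetical call site illustrating the new argument set
# (agent_index instead of agent_id, plus the new t argument).
collector.add_init_obs(
    episode_id=eps_id,
    agent_index=0,      # int index of the agent within its episode
    env_id=0,           # sub-env index in a vectorized setup
    t=-1,               # the initial obs sits one ts before t=0
    init_obs=obs,
    view_requirements=policy.view_requirements,
)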
@@ -133,7 +140,7 @@ def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
                 continue
             # OBS are already shifted by -1 (the initial obs starts one ts
             # before all other data columns).
-            shift = view_req.shift - \
+            shift = view_req.data_rel_pos - \
                 (1 if data_col == SampleBatch.OBS else 0)
             if data_col not in np_data:
                 np_data[data_col] = to_float_np_array(self.buffers[data_col])

Review comment (on the `data_rel_pos` rename): Renamed this b/c this will support (in the upcoming PRs) not just a single shift (int), but also:
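A small illustrative snippet (my own, not from the PR) of why the OBS column needs the extra -1: the obs buffer already contains the reset observation, so it is one element longer than the other columns.

# Illustration only: obs has T + 1 entries, rewards have T entries.
obs_buffer = ["obs_reset", "obs_t0", "obs_t1"]  # len == T + 1
rew_buffer = [1.0, 0.5]                         # len == T
is_obs_col = True
data_rel_pos = 0
# A data_rel_pos of 0 on OBS becomes an effective shift of -1, which lines
# obs_t0/obs_t1 up with the rewards received after those observations.
shift = data_rel_pos - (1 if is_obs_col else 0)
assert shift == -1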
@@ -187,7 +194,10 @@ def _build_buffers(self, single_row: Dict[str, TensorType]) -> None:
         for col, data in single_row.items():
             if col in self.buffers:
                 continue
-            shift = self.shift_before - (1 if col == SampleBatch.OBS else 0)
+            shift = self.shift_before - (1 if col in [
+                SampleBatch.OBS, SampleBatch.EPS_ID, SampleBatch.AGENT_INDEX,
+                "env_id", "t"
+            ] else 0)
             # Python primitive or dict (e.g. INFOs).
             if isinstance(data, (int, float, bool, str, dict)):
                 self.buffers[col] = [0 for _ in range(shift)]
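An illustrative sketch (assumptions mine) of what this pre-padding produces for a primitive column when `shift_before=2`, and why the OBS-like columns get one element less: their reset-time entry is appended separately in `add_init_obs`.

shift_before = 2
# Regular column (e.g. rewards): pad the full shift_before.
reward_buffer = [0 for _ in range(shift_before)]       # -> [0, 0]
# OBS-like columns (obs, eps_id, agent_index, env_id, t) receive a
# reset-time entry via add_init_obs, so they are padded with one less.
obs_like_buffer = [0 for _ in range(shift_before - 1)]  # -> [0]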
@@ -360,7 +370,7 @@ def episode_step(self, episode_id: EpisodeID) -> None:

     @override(_SampleCollector)
     def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
-                     env_id: EnvID, policy_id: PolicyID,
+                     env_id: EnvID, policy_id: PolicyID, t: int,
                      init_obs: TensorType) -> None:
         # Make sure our mappings are up to date.
         agent_key = (episode.episode_id, agent_id)
@@ -378,8 +388,9 @@ def add_init_obs(self, episode: MultiAgentEpisode, agent_id: AgentID,
             self.agent_collectors[agent_key] = _AgentCollector()
         self.agent_collectors[agent_key].add_init_obs(
             episode_id=episode.episode_id,
-            agent_id=agent_id,
+            agent_index=episode._agent_index(agent_id),
             env_id=env_id,
+            t=t,
             init_obs=init_obs,
             view_requirements=view_reqs)

@@ -429,7 +440,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \
             # Create the batch of data from the different buffers.
             data_col = view_req.data_col or view_col
             time_indices = \
-                view_req.shift - (
+                view_req.data_rel_pos - (
                     1 if data_col in [SampleBatch.OBS, "t", "env_id",
                                       SampleBatch.EPS_ID,
                                       SampleBatch.AGENT_INDEX] else 0)
@@ -272,9 +272,11 @@ def __init__(
             output_creator (Callable[[IOContext], OutputWriter]): Function that
                 returns an OutputWriter object for saving generated
                 experiences.
-            remote_worker_envs (bool): If using num_envs > 1, whether to create
-                those new envs in remote processes instead of in the current
-                process. This adds overheads, but can make sense if your envs
+            remote_worker_envs (bool): If using num_envs_per_worker > 1,
+                whether to create those new envs in remote processes instead of
+                in the current process. This adds overheads, but can make sense
+                if your envs are expensive to step/reset (e.g., for StarCraft).
+                Use this cautiously, overheads are significant!
             remote_env_batch_wait_ms (float): Timeout that remote workers
                 are waiting when polling environments. 0 (continue when at
                 least one env is ready) is a reasonable default, but optimal
                 value could be obtained by measuring your environment
                 step/reset and model inference perf.

Review comment: 👍
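A hedged sketch (not part of this diff) of how these settings are typically combined in a Trainer config dict; the numeric values are illustrative only.

config = {
    "num_workers": 2,
    "num_envs_per_worker": 4,       # vectorize 4 envs on each rollout worker ...
    "remote_worker_envs": True,     # ... and step each env in its own remote process
    "remote_env_batch_wait_ms": 0,  # return as soon as at least one env is ready
}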
@@ -29,7 +29,7 @@ class ViewRequirement:
     def __init__(self,
                  data_col: Optional[str] = None,
                  space: gym.Space = None,
-                 shift: Union[int, List[int]] = 0,
+                 data_rel_pos: Union[int, List[int]] = 0,
                  used_for_training: bool = True):
         """Initializes a ViewRequirement object.

Review comment: Why not keep it as shift? It seems to be intuitive
Reply: I liked
@@ -40,19 +40,20 @@ def __init__(self,
             space (gym.Space): The gym Space used in case we need to pad data
                 in inaccessible areas of the trajectory (t<0 or t>H).
                 Default: Simple box space, e.g. rewards.
-            shift (Union[int, List[int]]): Single shift value of list of
-                shift values to use relative to the underlying `data_col`.
+            data_rel_pos (Union[int, str, List[int]]): Single shift value or
+                list of relative positions to use (relative to the underlying
+                `data_col`).
                 Example: For a view column "prev_actions", you can set
-                `data_col="actions"` and `shift=-1`.
+                `data_col="actions"` and `data_rel_pos=-1`.
                 Example: For a view column "obs" in an Atari framestacking
                 fashion, you can set `data_col="obs"` and
-                `shift=[-3, -2, -1, 0]`.
+                `data_rel_pos=[-3, -2, -1, 0]`.
             used_for_training (bool): Whether the data will be used for
                 training. If False, the column will not be copied into the
                 final train batch.
         """
         self.data_col = data_col
         self.space = space or gym.spaces.Box(
             float("-inf"), float("inf"), shape=())
-        self.shift = shift
+        self.data_rel_pos = data_rel_pos
         self.used_for_training = used_for_training
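The two docstring examples above, written out with the renamed argument as a minimal sketch; the import path is assumed from the RLlib source tree.

from ray.rllib.policy.view_requirement import ViewRequirement

view_requirements = {
    # "prev_actions" is the action taken one step before the current one.
    "prev_actions": ViewRequirement(data_col="actions", data_rel_pos=-1),
    # Atari-style framestacking: the last 4 observations per row.
    "obs": ViewRequirement(data_col="obs", data_rel_pos=[-3, -2, -1, 0]),
}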