Skip to content

Commit

Permalink
add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Zachary Ravichandran committed Apr 27, 2022
1 parent 8df551f commit 1b2c782
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/rllib_policies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def __init__(
self.feature_length = feature_length
self.fields = fields

def get_obs(self, obs: Dict[str, torch.Tensor], rnn_input: bool) -> torch.Tensor:
def get_obs(
self, obs: Dict[str, torch.Tensor], rnn_input: bool
) -> Union[List[torch.Tensor], Dict[str, torch.Tensor]]:
"""Read observations from input dictionary.
Parameters
Expand Down
84 changes: 67 additions & 17 deletions src/rllib_policies/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,14 @@ def __init__(
)

def get_dense_layers(self, layers: List[int]) -> torch.nn.Module:
"""Get set of linear layers.
"""Get subnetwork composed of linear layers.
Args:
layers (List[int]): Dense layer sizes. Implies
number of layers.
layers (List[int]): Number of features per layer.
The length of `layers` implies the number of layers (`len(layers) - 1` linear layers).
Returns:
torch.nn.Sequential: PyTorch module of dense layers.
"""
return torch.nn.Sequential(
*[torch.nn.Linear(layers[i], layers[i + 1]) for i in range(len(layers) - 1)]
Expand Down Expand Up @@ -319,14 +322,29 @@ def forward(
return graph_features


class AgentFrameGCN(GCN):
class ActionLayerPolicy(GCN):
def __init__(
self,
in_features: torch.Tensor,
in_features: int,
graph_conv_features: List[int],
embedding_size: List[int],
n_frame_nodes: int,
):
"""Action layer policy.
The action layer is composed of a fixed ring of nodes placed around
the agent and connected to the rest of the DSG given certain connectivity rules.
The policy learns features over an observed graph using a GNN.
Then only features from the action layer are mapped to a predicted action.
See https://arxiv.org/abs/2108.01176 for more details.
Args:
in_features (int): Number of input features per node.
graph_conv_features (List[int]): Number of channels per GCN.
embedding_size (List[int]): Length of learned per-node features.
n_frame_nodes (int): Number of action layer nodes.
"""
super().__init__(in_features, graph_conv_features, embedding_size)
self.n_agent_frame_nodes = n_frame_nodes
if len(embedding_size) == 0:
Expand All @@ -343,6 +361,19 @@ def postprocess_padded_nodes(
def get_policy_features(
self, h: torch.Tensor, batch_index: torch.Tensor
) -> torch.Tensor:
"""Return features from the action layer.
Note that action layer features must be first in
the sequence of per-node learned features `h`.
Args:
h (torch.Tensor): Per-node learned features.
batch_index (torch.Tensor): Source batch index for
each feature in `h`.
Returns:
torch.Tensor: Output features derived from the action layer.
"""
# if no batch index is given, assume all data is from same graph
if batch_index is None:
batch_index = torch.zeros(h.shape[0], dtype=torch.int64).to(h.device)
Expand All @@ -356,33 +387,52 @@ def get_policy_features(
return h


class GNNBase(NetworkBase):
class ActionLayerBase(NetworkBase):
def __init__(
self,
fields: Dict[str, str],
in_features,
graph_conv_features,
embedding_size,
n_frame_nodes,
in_features: int,
graph_conv_features: List[int],
embedding_size: List[int],
n_frame_nodes: int,
):
net = AgentFrameGCN(
"""Wraps `ActionLayerPolicy` in `NetworkBase` interface.
Args:
fields (Dict[str, str]): Dictionary mapping observation field
to key in input observation dictionary.
in_features (int): Per-node input feature length.
graph_conv_features (List[int]): Per-layer GCN channel size.
embedding_size (List[int]): Length of learned per-node features.
n_frame_nodes (int): Number of action layer nodes.
"""
net = ActionLayerPolicy(
in_features, graph_conv_features, embedding_size, n_frame_nodes
)
super().__init__(net, net.out_features, fields.keys())
self.fields = fields

def get_obs(self, obs: Dict[str, torch.Tensor], rnn_input: bool) -> torch.Tensor:
def get_obs(
self, obs: Dict[str, torch.Tensor], rnn_input: bool
) -> Dict[str, torch.Tensor]:
"""Get required observations from input dictionary. Observations are
placed into a new dictionary and are reshaped, if needed.
Args:
obs (Dict[str, torch.Tensor]): Dictionary of input observations.
rnn_input (bool): True if this is an input to a recurrent network.
Returns:
Dict[str, torch.Tensor]: Dictionary of observations.
"""
graph_obs = {}

for field, obs_key in self.fields.items():
data = obs[obs_key]
if rnn_input: # combine time and batch inds
if rnn_input: # combine time and batch inds for RNN input
data = data.reshape((-1,) + tuple(data.shape[2:]))
if field in ("node_shapes", "edge_shapes"):
# if isinstance(data, torch.Tensor):
data = data.type(torch.int32)
# elif isinstance(data, np.ndarray):
# data = data.astype(np.int32)
graph_obs[field] = data

return graph_obs
Expand All @@ -406,7 +456,7 @@ def init_nets(
n_layer_nodes
):
return [
GNNBase(
ActionLayerBase(
fields, in_features, graph_conv_features, embedding_size, n_layer_nodes
)
]

0 comments on commit 1b2c782

Please sign in to comment.