diff --git a/src/rllib_policies/base.py b/src/rllib_policies/base.py index d195558..39a19fb 100644 --- a/src/rllib_policies/base.py +++ b/src/rllib_policies/base.py @@ -60,7 +60,9 @@ def __init__( self.feature_length = feature_length self.fields = fields - def get_obs(self, obs: Dict[str, torch.Tensor], rnn_input: bool) -> torch.Tensor: + def get_obs( + self, obs: Dict[str, torch.Tensor], rnn_input: bool + ) -> Union[List[torch.Tensor], Dict[str, torch.Tensor]]: """Read observations from input dictionary. Parameters diff --git a/src/rllib_policies/gnn.py b/src/rllib_policies/gnn.py index 62c7d78..8d5e0d9 100644 --- a/src/rllib_policies/gnn.py +++ b/src/rllib_policies/gnn.py @@ -75,11 +75,14 @@ def __init__( ) def get_dense_layers(self, layers: List[int]) -> torch.nn.Module: - """Get set of linear layers. + """Get subnetwork composed of linear layers. Args: - layers (List[int]): Dense layer sizes. Implies - number of layers. + layers (List[int]): Number of features per layer. + Length in `layers` implies number of layers. + + Returns: + torch.nn.Sequential: Pytorch module of dense layers. """ return torch.nn.Sequential( *[torch.nn.Linear(layers[i], layers[i + 1]) for i in range(len(layers) - 1)] @@ -319,14 +322,29 @@ def forward( return graph_features -class AgentFrameGCN(GCN): +class ActionLayerPolicy(GCN): def __init__( self, - in_features: torch.Tensor, + in_features: int, graph_conv_features: List[int], embedding_size: List[int], n_frame_nodes: int, ): + """Action layer policy. + + The action layer is composed of a fixed ring of nodes placed around + the agent and connected to the rest of the DSG given certain connectivity rules. + The policy learns features over an observed graph using a GNN. + Then only features from the action layer are mapped to a predicted action. + + See https://arxiv.org/abs/2108.01176 for more details. + + Args: + in_features (int): Number of input features per node. + graph_conv_features (List[int]): Number of channels per GCN. + embedding_size (List[int]): Length of leanred per-node features. + n_frame_nodes (int): Number of action layer nodes. + """ super().__init__(in_features, graph_conv_features, embedding_size) self.n_agent_frame_nodes = n_frame_nodes if len(embedding_size) == 0: @@ -343,6 +361,19 @@ def postprocess_padded_nodes( def get_policy_features( self, h: torch.Tensor, batch_index: torch.Tensor ) -> torch.Tensor: + """Return features from the action layer. + + Note that action layers features must be first in + the sequence of per-node learned features `h`. + + Args: + h (torch.Tensor): Per-node learned features. + batch_index (torch.Tensor): Source batch index for + each feature in `h`. + + Returns: + torch.Tensor: Output features derived from the action layer. + """ # if no batch index is given, assume all data is from same graph if batch_index is None: batch_index = torch.zeros(h.shape[0], dtype=torch.int64).to(h.device) @@ -356,33 +387,52 @@ def get_policy_features( return h -class GNNBase(NetworkBase): +class ActionLayerBase(NetworkBase): def __init__( self, fields: Dict[str, str], - in_features, - graph_conv_features, - embedding_size, - n_frame_nodes, + in_features: int, + graph_conv_features: List[int], + embedding_size: List[int], + n_frame_nodes: int, ): - net = AgentFrameGCN( + """Wraps `ActionLayerPolicy` in `NetworkBase` interface. + + Args: + fields (Dict[str, str]): Dictionary mapping observation field + to key in input observation dictionary. + in_features (int):Per-node input feature length. + graph_conv_features (List[int]): Per-layer GCN channel size. + embedding_size (List[int]): Length of learned per-node features. + n_frame_nodes (int): Number of action layer nodes. + """ + net = ActionLayerPolicy( in_features, graph_conv_features, embedding_size, n_frame_nodes ) super().__init__(net, net.out_features, fields.keys()) self.fields = fields - def get_obs(self, obs: Dict[str, torch.Tensor], rnn_input: bool) -> torch.Tensor: + def get_obs( + self, obs: Dict[str, torch.Tensor], rnn_input: bool + ) -> Dict[str, torch.Tensor]: + """Get required observations from input dictionary. Observations are + placed into a new dictionary and are reshaped, if needed. + + Args: + obs (Dict[str, torch.Tensor]): Dictionary of input observations. + rnn_input (bool): True is this is an input to a recurrent network. + + Returns: + Dict[str, torch.Tensor]: Dictionary of observations. + """ graph_obs = {} for field, obs_key in self.fields.items(): data = obs[obs_key] - if rnn_input: # combine time and batch inds + if rnn_input: # combine time and batch inds for RNN input data = data.reshape((-1,) + tuple(data.shape[2:])) if field in ("node_shapes", "edge_shapes"): - # if isinstance(data, torch.Tensor): data = data.type(torch.int32) - # elif isinstance(data, np.ndarray): - # data = data.astype(np.int32) graph_obs[field] = data return graph_obs @@ -406,7 +456,7 @@ def init_nets( n_layer_nodes ): return [ - GNNBase( + ActionLayerBase( fields, in_features, graph_conv_features, embedding_size, n_layer_nodes ) ]