Commit

add comments

Zachary Ravichandran committed Apr 27, 2022
1 parent f895389 commit 8df551f
Showing 1 changed file with 81 additions and 4 deletions.

85 changes: 81 additions & 4 deletions src/rllib_policies/actor_critic.py
@@ -46,6 +46,29 @@ def __init__(
],
**network_args: Dict[str, Any],
):
"""Actor-Critic model implemented in Rllib using Pytorch.
Model is composed of the following componenets
- backbone function b(x): maps observations to some embedding
- dense layers d(x):
- policy function p(x)
- value function v(x)
The backbone function maps environment observations defined by `obs_space` to
some embedding, which is further processed by the dense layers (if present).
This final embedding is then used by both the policy and value functions.
If specified in the model configuration, the policy and value functions may
optionally share weights.
Args:
obs_space (gym.spaces.Space): Policy observation space.
action_space (gym.spaces.Space): Policy action space.
num_outputs (int): Length of policy output vector.
model_config (Dict[str, Any]): Configuration passed to the
parent class TorchModelV2.
name (str): Policy name.
dense_layers (list, optional): Number of dense layers. Defaults to [ 512, ].
"""
TorchModelV2.__init__(
self, obs_space, action_space, num_outputs, model_config, name
)
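For context, a minimal sketch of how a model like this is typically registered and configured in RLlib; the class name ActorCritic and the import path are assumptions for illustration, and only `dense_layers` and the `**network_args` hook appear in this diff.

    # Usage sketch (hypothetical names; not part of this commit).
    from ray.rllib.models import ModelCatalog
    from rllib_policies.actor_critic import ActorCritic  # assumed import path

    ModelCatalog.register_custom_model("actor_critic", ActorCritic)

    config = {
        "framework": "torch",
        "model": {
            "custom_model": "actor_critic",
            # Entries here are forwarded to __init__ as **network_args.
            "custom_model_config": {"dense_layers": [512]},
        },
    }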
@@ -70,12 +93,27 @@ def forward(
state: List[Tensor],
seq_len: Tensor,
) -> Tuple[Tensor, List[Tensor]]:
"""Pass observation through network.
Args:
input_dict (Dict[str, Tensor]): Input observation.
state (List[Tensor]): Current policy state (e.g., RNN state).
seq_len (Tensor): Current sequence length (e.g., sequence in episode).
Returns:
Tuple[Tensor, List[Tensor]]: Action produced by policy and new state.
"""
linear_in = self.forward_nets(input_dict["obs"], rnn_input=False)
self._features = self.linear(linear_in)
return self.policy(self._features), state

@override(ModelV2)
-def value_function(self):
+def value_function(self) -> torch.Tensor:
"""Get current value predicted by value function.
Returns:
torch.Tensor: Predicted value.
"""
return torch.reshape(self.value_branch(self._features), [-1])
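As a usage note, RLlib calls `forward()` on a batch before `value_function()`, which is why `self._features` is cached on the model. A hedged sketch of that contract (`model` and `obs_batch` are illustrative stand-ins, not names from this commit):

    # Illustrative only: roughly how RLlib drives the model on each batch.
    # `model` is an instance of the class above; `obs_batch` a float tensor [B, obs_dim].
    logits, state = model.forward({"obs": obs_batch}, state=[], seq_len=None)
    values = model.value_function()  # reads self._features cached by forward()
    assert values.shape[0] == obs_batch.shape[0]  # one value per batch element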


@@ -92,6 +130,33 @@ def __init__(
rnn_type: str = "LSTM",
**network_args: Dict[str, Any],
):
"""Actor-Critic model with recurrent model implemented in Rllib using Pytorch.
Model is composed of the following
componenets
- backbone function b(x): maps observations to some embedding
- dense layers d(x):
- recurrent network r(x)
- policy function p(x)
- value function v(x)
The backbone function maps environment observations defined by `obs_space` to
some embedding, which is further processed by the dense layers (if present),
and then processed by the recurrent model. This final embedding is then used
by both the policy and value functions. If specified in the model configuration,
the policy and value function backbones may share weights.
Args:
obs_space (gym.spaces.Space): Policy observation space.
action_space (gym.spaces.Space): Policy action space.
num_outputs (int): Length of policy output vector.
model_config (Dict[str, Any]): Configuration passed to the
parent class TorchModelV2.
name (str): Policy name.
dense_layers (list, optional): Number of dense layers. Defaults to [ 512, ].
hidden_size (int, optional): RNN hidden size. Defaults to 512.
rnn_type (str, optional): RNN type. Type must be a module in torch.nn). Defaults to "LSTM".
"""
TorchRNN.__init__(
self, obs_space, action_space, num_outputs, model_config, name
)
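A corresponding configuration sketch for the recurrent variant; the registry name is an assumption, `dense_layers`, `hidden_size`, and `rnn_type` come from the docstring above, and `max_seq_len` is RLlib's standard setting for truncating RNN training sequences.

    config = {
        "framework": "torch",
        "model": {
            "custom_model": "recurrent_actor_critic",  # assumed registry name
            "custom_model_config": {
                "dense_layers": [512],
                "hidden_size": 512,
                "rnn_type": "GRU",  # any recurrent module name in torch.nn
            },
            "max_seq_len": 20,
        },
    }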
@@ -120,6 +185,18 @@ def __init__(
def forward_rnn(
self, input: Tensor, state: List[Tensor], seq_len: Tensor
) -> Tuple[Tensor, List[Tensor]]:
"""Pass input observation through network.
Args:
input (Tensor): Input tensor of flattened observations.
Observations will be unflattened to the dimensions given
by `self.obs_space` in the `self.forward_nets` call.
state (List[Tensor]): List of states over the RNN sequence.
seq_len (Tensor): RNN sequence length.
Returns:
Tuple[Tensor, List[Tensor]]: Predicted action and updated states.
"""
base_features = self.forward_nets(input, rnn_input=True)
base_features = self.linear(base_features)
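Note that `forward_rnn` receives time-batched input: RLlib's `TorchRNN.forward` wrapper adds a time dimension to the flattened sample batch before delegating here. A small shape sketch (dimensions are illustrative assumptions):

    import torch

    batch, max_t, obs_dim = 4, 8, 32
    flat_obs = torch.rand(batch * max_t, obs_dim)        # rows of the sample batch
    rnn_input = flat_obs.reshape(batch, max_t, obs_dim)  # what forward_rnn receives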

@@ -179,11 +256,11 @@ def get_initial_state(self) -> List[Tensor]:
return h
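For reference, `get_initial_state` returns one zero tensor per recurrent state; a hedged sketch of the usual layout (assuming hidden_size=512):

    import torch

    # Illustrative: an LSTM carries hidden and cell states; a GRU only a hidden state.
    h = [torch.zeros(512), torch.zeros(512)]  # LSTM -> [h0, c0]
    # h = [torch.zeros(512)]                  # GRU  -> [h0]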

@override(ModelV2)
-def value_function(self):
+def value_function(self) -> torch.Tensor:
return torch.reshape(self.value_branch(self._features), [-1])

-def get_weights(self):
+def get_weights(self) -> None:
return None

-def set_weights(self, weights):
+def set_weights(self, weights: torch.Tensor) -> None:
pass
