Skip to content

Commit

Permalink
polish(rjy): polish comments in pg model
Browse files Browse the repository at this point in the history
  • Loading branch information
nighood committed Oct 19, 2023
1 parent 875e9b5 commit a9981fa
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ding/model/template/ngu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class NGU(nn.Module):
q_leaning.py. The implementation mentioned in the original paper is 'adapt the R2D2 agent that uses the \
dueling network architecture with an LSTM layer after a convolutional neural network'. The NGU network \
includes encoder, LSTM core(rnn) and head.
Interface:
Interface:
``__init__``, ``forward``.
"""

Expand Down
44 changes: 44 additions & 0 deletions ding/model/template/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@

@MODEL_REGISTRY.register('pg')
class PG(nn.Module):
"""
Overview:
The neural network and computation graph of algorithms related to Policy Gradient(PG) \
(https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf). \
The PG model is composed of two parts: encoder and head. Encoders are used to extract the feature \
from various observation. Heads are used to predict corresponding action logit. \
Interface:
``__init__``, ``forward``.
"""

def __init__(
self,
Expand All @@ -23,6 +32,31 @@ def __init__(
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
) -> None:
"""
Overview:
Initialize the PG model according to corresponding input arguments.
Arguments:
- obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
- action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
- action_space (:obj:`str`): The type of different action spaces, including ['discrete', 'continuous'], \
then will instantiate corresponding head, including ``DiscreteHead`` and ``ReparameterizationHead``.
- encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
the last element must match ``head_hidden_size``.
- head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``head`` network, defaults \
to None, it must match the last element of ``encoder_hidden_size_list``.
- head_layer_num (:obj:`int`): The num of layers used in the ``head`` network to compute action.
- activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
if ``None`` then default set it to ``nn.ReLU()``.
- norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
Examples:
>>> model = PG((4, 84, 84), 5)
>>> inputs = torch.randn(8, 4, 84, 84)
>>> outputs = model(inputs)
>>> assert isinstance(outputs, dict)
>>> assert outputs['logit'].shape == (8, 5)
>>> assert outputs['dist'].sample().shape == (8, )
"""
super(PG, self).__init__()
# For compatibility: 1, (1, ), [4, 32, 32]
obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
Expand Down Expand Up @@ -57,6 +91,16 @@ def __init__(
raise KeyError("not support action space: {}".format(self.action_space))

def forward(self, x: torch.Tensor) -> Dict:
"""
Overview:
PG forward computation graph, input observation tensor to predict policy distribution.
Arguments:
- x (:obj:`torch.Tensor`): The input observation tensor data.
Returns:
- outputs (:obj:`torch.distributions`): The output policy distribution. If action space is \
discrete, the output is Categorical distribution; if action space is continuous, the output is Normal \
distribution.
"""
x = self.encoder(x)
x = self.head(x)
if self.action_space == 'discrete':
Expand Down

0 comments on commit a9981fa

Please sign in to comment.