From 26c9e80f42cad14fe2558a80a0b74cc52b4714a5 Mon Sep 17 00:00:00 2001 From: Ren Jiyuan <47732381+nighood@users.noreply.github.com> Date: Mon, 16 Oct 2023 17:18:55 +0800 Subject: [PATCH] polish(rjy): polish comments of qmix/pdqn/mavac (#736) * polish(rjy): polish comments of qmix/pdqn/mavac * polish(rjy): polish the comments of MAVAC * polish(rjy): fix the format of mavac * polish(rjy): polish comments in qmix model * polish(rjy): polish comments of pdqn model * polish(rjy): polish comments according to review * polish(rjy): fix style --- ding/model/template/mavac.py | 251 ++++++++++++++++++----------------- ding/model/template/pdqn.py | 76 +++++++---- ding/model/template/qmix.py | 133 ++++++++++--------- 3 files changed, 251 insertions(+), 209 deletions(-) diff --git a/ding/model/template/mavac.py b/ding/model/template/mavac.py index 47a1171dbf..cdd521f2b1 100644 --- a/ding/model/template/mavac.py +++ b/ding/model/template/mavac.py @@ -3,17 +3,20 @@ import torch.nn as nn from ding.utils import SequenceType, squeeze, MODEL_REGISTRY -from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \ - FCEncoder, ConvEncoder +from ..common import ReparameterizationHead, RegressionHead, DiscreteHead @MODEL_REGISTRY.register('mavac') class MAVAC(nn.Module): - r""" + """ Overview: - The MAVAC model. + The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) for \ + multi-agent, such as MAPPO(https://arxiv.org/abs/2103.01955). This model now supports discrete and \ + continuous action space. The MAVAC is composed of four parts: ``actor_encoder``, ``critic_encoder``, \ + ``actor_head`` and ``critic_head``. Encoders are used to extract the feature from various observation. \ + Heads are used to predict corresponding value or action logit. Interfaces: - ``__init__``, ``forward``, ``compute_actor``, ``compute_critic`` + ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``. """ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic'] @@ -33,26 +36,36 @@ def __init__( sigma_type: Optional[str] = 'independent', bound_type: Optional[str] = None, ) -> None: - r""" + """ Overview: - Init the VAC Model according to arguments. + Init the MAVAC Model according to arguments. Arguments: - - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space. - - action_shape (:obj:`Union[int, SequenceType]`): Action's space. - - share_encoder (:obj:`bool`): Whether share encoder. - - continuous (:obj:`bool`): Whether collect continuously. - - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder`` - - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``. - - actor_head_layer_num (:obj:`int`): - The num of layers used in the network to compute Q value output for actor's nn. - - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``. + - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent, \ + such as 8 or [4, 84, 84]. + - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space, such as 8 or [4, 84, 84]. + - action_shape (:obj:`Union[int, SequenceType]`): Action space shape for single agent, such as 6 \ + or [2, 3, 3]. + - agent_num (:obj:`int`): This parameter is temporarily reserved. 
This parameter may be required for \ + subsequent changes to the model + - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \ + to 256, it must match the last element of ``agent_obs_shape``. + - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action. + - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \ + to 512, it must match the last element of ``global_obs_shape``. - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output for critic's nn. - - activation (:obj:`Optional[nn.Module]`): - The type of activation function to use in ``MLP`` the after ``layer_fn``, - if ``None`` then default set to ``nn.ReLU()`` - - norm_type (:obj:`Optional[str]`): - The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details` + - action_space (:obj:`Union[int, SequenceType]`): The type of different action spaces, including \ + ['discrete', 'continuous'], then will instantiate corresponding head, including ``DiscreteHead`` \ + and ``ReparameterizationHead``. + - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` the after \ + ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``. + - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ + ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']. + - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \ + ``ding.torch_utils.network.dreamer.ReparameterizationHead`` for more details, in MAPPO, it defaults \ + to ``independent``, which means state-independent sigma parameters. + - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \ + to ``None``, which means no bound. """ super(MAVAC, self).__init__() agent_obs_shape: int = squeeze(agent_obs_shape) @@ -61,25 +74,6 @@ def __init__( self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape self.action_space = action_space # Encoder Type - if isinstance(agent_obs_shape, int) or len(agent_obs_shape) == 1: - encoder_cls = FCEncoder - elif len(agent_obs_shape) == 3: - encoder_cls = ConvEncoder - else: - raise RuntimeError( - "not support obs_shape for pre-defined encoder: {}, please customize your own DQN". - format(agent_obs_shape) - ) - if isinstance(global_obs_shape, int) or len(global_obs_shape) == 1: - global_encoder_cls = FCEncoder - elif len(global_obs_shape) == 3: - global_encoder_cls = ConvEncoder - else: - raise RuntimeError( - "not support obs_shape for pre-defined encoder: {}, please customize your own DQN". - format(global_obs_shape) - ) - # We directly connect the Head after a Liner layer instead of using the 3-layer FCEncoder. # In SMAC task it can obviously improve the performance. # Users can change the model according to their own needs. @@ -126,77 +120,87 @@ def __init__( self.critic = nn.ModuleList(self.critic) def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict: - r""" + """ Overview: - Use encoded embedding tensor to predict output. - Parameter updates with VAC's MLPs forward setup. + MAVAC forward computation graph, input observation tensor to predict state value or action logit. \ + ``mode`` includes ``compute_actor``, ``compute_critic``, ``compute_actor_critic``. 
+            Different ``mode`` will forward with different network modules to get different outputs and save \
+            computation.
         Arguments:
-            Forward with ``'compute_actor'`` or ``'compute_critic'``:
-                - inputs (:obj:`torch.Tensor`):
-                    The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
-                    Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``.
+            - inputs (:obj:`Dict`): The input dict including observation and related info, \
+                whose key-values vary with different ``mode``.
+            - mode (:obj:`str`): The forward mode; all the modes are defined at the beginning of this class.
         Returns:
-            - outputs (:obj:`Dict`):
-                Run with encoder and head.
+            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph, whose key-values vary \
+                with different ``mode``.

-            Forward with ``'compute_actor'``, Necessary Keys:
-                - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.
-
-            Forward with ``'compute_critic'``, Necessary Keys:
-                - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
-        Shapes:
-            - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N corresponding ``hidden_size``
-            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
-            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
-
-        Actor Examples:
-            >>> model = VAC(64,128)
-            >>> inputs = torch.randn(4, 64)
+        Examples (Actor):
+            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
             >>> actor_outputs = model(inputs,'compute_actor')
-            >>> assert actor_outputs['logit'].shape == torch.Size([4, 128])
+            >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])

-        Critic Examples:
-            >>> model = VAC(64,64)
-            >>> inputs = torch.randn(4, 64)
+        Examples (Critic):
+            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
             >>> critic_outputs = model(inputs,'compute_critic')
-            >>> critic_outputs['value']
-            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=)
+            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])

-        Actor-Critic Examples:
-            >>> model = VAC(64,64)
-            >>> inputs = torch.randn(4, 64)
+        Examples (Actor-Critic):
+            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
             >>> outputs = model(inputs,'compute_actor_critic')
-            >>> outputs['value']
-            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=)
-            >>> assert outputs['logit'].shape == torch.Size([4, 64])
+            >>> assert outputs['value'].shape == torch.Size([10, 8])
+            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
         """
         assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
         return getattr(self, mode)(inputs)

-    def compute_actor(self, x: torch.Tensor) -> Dict:
-        r"""
+    def compute_actor(self, x: Dict) -> Dict:
+        """
         Overview:
-            Execute parameter updates with ``'compute_actor'`` mode
-            Use encoded embedding tensor to predict output.
+            MAVAC forward computation graph for the actor part, \
+            predicting action logit with agent observation tensor in ``x``.
Arguments: - - inputs (:obj:`torch.Tensor`): - The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. - ``hidden_size = actor_head_hidden_size`` + - x (:obj:`Dict`): Input data dict with keys ['agent_state', 'action_mask'(optional)]. + - agent_state: (:obj:`torch.Tensor`): Each agent local state(obs). + - action_mask(optional): (:obj:`torch.Tensor`): When ``action_space`` is discrete, action_mask needs \ + to be provided to mask illegal actions. Returns: - outputs (:obj:`Dict`): - Run with encoder and head. - + The output dict of MAVAC's forward computation graph for actor, including ``logit``. ReturnsKeys: - - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``. + - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \ + the same dimension real-value ranged tensor of possible action choices, and for continuous action \ + space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \ + same as the number of continuous actions. Shapes: - - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape`` + - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \ + and M is ``agent_num``. Examples: - >>> model = VAC(64,64) - >>> inputs = torch.randn(4, 64) + >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14) + >>> inputs = { + 'agent_state': torch.randn(10, 8, 64), + 'global_state': torch.randn(10, 8, 128), + 'action_mask': torch.randint(0, 2, size=(10, 8, 14)) + } >>> actor_outputs = model(inputs,'compute_actor') - >>> assert actor_outputs['action'].shape == torch.Size([4, 64]) + >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14]) + """ if self.action_space == 'discrete': action_mask = x['action_mask'] @@ -213,29 +217,30 @@ def compute_actor(self, x: torch.Tensor) -> Dict: return {'logit': logit} def compute_critic(self, x: Dict) -> Dict: - r""" + """ Overview: - Execute parameter updates with ``'compute_critic'`` mode - Use encoded embedding tensor to predict output. + MAVAC forward computation graph for critic part. \ + Predict state value with global observation tensor in ``x``. Arguments: - - inputs (:obj:`Dict`): - The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``. - ``hidden_size = critic_head_hidden_size`` + - x (:obj:`Dict`): Input data dict with keys ['global_state']. + - global_state: (:obj:`torch.Tensor`): Global state(obs). Returns: - - outputs (:obj:`Dict`): - Run with encoder and head. - - Necessary Keys: - - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size. + - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for critic, \ + including ``value``. + ReturnsKeys: + - value (:obj:`torch.Tensor`): The predicted state value tensor. Shapes: - - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size. + - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``. 
        Examples:
-            >>> model = VAC(64,64)
-            >>> inputs = torch.randn(4, 64)
+            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
             >>> critic_outputs = model(inputs,'compute_critic')
-            >>> critic_outputs['value']
-            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=)
+            >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
         """
         x = self.critic_encoder(x['global_state'])
@@ -243,37 +248,33 @@ def compute_critic(self, x: Dict) -> Dict:
         return {'value': x['pred']}

     def compute_actor_critic(self, x: Dict) -> Dict:
-        r"""
+        """
         Overview:
-            Execute parameter updates with ``'compute_actor_critic'`` mode
-            Use encoded embedding tensor to predict output.
+            MAVAC forward computation graph for both actor and critic parts, input observation to predict action \
+            logit and state value.
         Arguments:
-            - inputs (:obj:`torch.Tensor`): The encoded embedding tensor.
-
+            - x (:obj:`Dict`): The input dict containing ``agent_state``, ``global_state`` and other related info.
         Returns:
-            - outputs (:obj:`Dict`):
-                Run with encoder and head.
-
+            - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for both actor and critic, \
+                including ``logit`` and ``value``.
         ReturnsKeys:
             - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.
             - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
         Shapes:
-            - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
-            - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
+            - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
+                and M is ``agent_num``.
+            - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
         Examples:
-            >>> model = VAC(64,64)
-            >>> inputs = torch.randn(4, 64)
+            >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+            >>> inputs = {
                    'agent_state': torch.randn(10, 8, 64),
                    'global_state': torch.randn(10, 8, 128),
                    'action_mask': torch.randint(0, 2, size=(10, 8, 14))
                }
             >>> outputs = model(inputs,'compute_actor_critic')
-            >>> outputs['value']
-            tensor([0.0252, 0.0235, 0.0201, 0.0072], grad_fn=)
-            >>> assert outputs['logit'].shape == torch.Size([4, 64])
-
-
-        .. note::
-            ``compute_actor_critic`` interface aims to save computation when shares encoder.
-            Returning the combination dictionry.
-
+            >>> assert outputs['value'].shape == torch.Size([10, 8])
+            >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
         """
         logit = self.compute_actor(x)['logit']
         value = self.compute_critic(x)['value']
diff --git a/ding/model/template/pdqn.py b/ding/model/template/pdqn.py
index 23b112364c..ec94cb3fe1 100644
--- a/ding/model/template/pdqn.py
+++ b/ding/model/template/pdqn.py
@@ -11,6 +11,17 @@

 @MODEL_REGISTRY.register('pdqn')
 class PDQN(nn.Module):
+    """
+    Overview:
+        The neural network and computation graph of PDQN(https://arxiv.org/abs/1810.06394v1) and \
+        MPDQN(https://arxiv.org/abs/1905.04388) algorithms for parameterized action space. \
+        This model supports parameterized action space with discrete ``action_type`` and continuous ``action_arg``. \
+        In principle, PDQN consists of x network (continuous action parameter network) and Q network (discrete \
+        action type network). 
But for simplicity, the code is split into ``encoder`` and ``actor_head``, which \ + contain the encoder and head of the above two networks respectively. + Interface: + ``__init__``, ``forward``, ``compute_discrete``, ``compute_continuous``. + """ mode = ['compute_discrete', 'compute_continuous'] def __init__( @@ -26,7 +37,7 @@ def __init__( multi_pass: Optional[bool] = False, action_mask: Optional[list] = None ) -> None: - r""" + """ Overview: Init the PDQN (encoder + head) Model according to input arguments. Arguments: @@ -37,17 +48,17 @@ def __init__( the last element must match ``head_hidden_size``. - dueling (:obj:`dueling`): Whether choose ``DuelingHead`` or ``DiscreteHead(default)``. - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network. - - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output + - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output. - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \ - if ``None`` then default set it to ``nn.ReLU()`` + if ``None`` then default set it to ``nn.ReLU()``. - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \ ``ding.torch_utils.fc_block`` for more details. - multi_pass (:obj:`Optional[bool]`): Whether to use multi pass version. - - action_mask: (:obj:`Optional[list]`): An action mask indicating how action args are - associated to each discrete action. For example, if there are 3 discrete action, - 4 continous action args, and the first discrete action associates with the first - continuous action args, the second discrete action associates with the second continuous - action args, and the third discrete action associates with the remaining 2 action args, + - action_mask: (:obj:`Optional[list]`): An action mask indicating how action args are \ + associated to each discrete action. For example, if there are 3 discrete action, \ + 4 continous action args, and the first discrete action associates with the first \ + continuous action args, the second discrete action associates with the second continuous \ + action args, and the third discrete action associates with the remaining 2 action args, \ the action mask will be like: [[1,0,0,0],[0,1,0,0],[0,0,1,1]] with shape 3*4. """ super(PDQN, self).__init__() @@ -120,33 +131,42 @@ def __init__( self.actor_head = nn.ModuleList([self.dis_head, self.cont_head]) # self.encoder = nn.ModuleList([self.dis_encoder, self.cont_encoder]) + # To speed up the training process, the X network and the Q network share the encoder for the state self.encoder = nn.ModuleList([self.cont_encoder, self.cont_encoder]) def forward(self, inputs: Union[torch.Tensor, Dict, EasyDict], mode: str) -> Dict: - r""" + """ Overview: PDQN forward computation graph, input observation tensor to predict q_value for \ - discrete actions and values for continuous action_args + discrete actions and values for continuous action_args. Arguments: - - inputs (:obj:`torch.Tensor`): Observation inputs + - inputs (:obj:`Union[torch.Tensor, Dict, EasyDict]`): Inputs including observation and \ + other info according to `mode`. - mode (:obj:`str`): Name of the forward mode. Shapes: - - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape`` + - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``. 
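+        Examples:
+            >>> # A minimal usage sketch; the shapes mirror the ``compute_continuous`` and
+            >>> # ``compute_discrete`` examples below and are illustrative only.
+            >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
+            >>> model = PDQN(4, act_shape)
+            >>> action_args = model.forward(torch.randn(64, 4), mode='compute_continuous')['action_args']
+            >>> inputs = {'state': torch.randn(64, 4), 'action_args': action_args}
+            >>> outputs = model.forward(inputs, mode='compute_discrete')
+            >>> assert outputs['logit'].shape == torch.Size([64, 3])
+            >>> assert outputs['action_args'].shape == torch.Size([64, 5])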
""" assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode) return getattr(self, mode)(inputs) def compute_continuous(self, inputs: torch.Tensor) -> Dict: - r""" + """ Overview: Use observation tensor to predict continuous action args. Arguments: - - inputs (:obj:`torch.Tensor`): Observation inputs - Shapes: - - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape`` + - inputs (:obj:`torch.Tensor`): Observation inputs. Returns: - - outputs (:obj:`Dict`): A dict with key 'action_args' - - 'action_args': the continuous action args + - outputs (:obj:`Dict`): A dict with key 'action_args'. + - 'action_args' (:obj:`torch.Tensor`): The continuous action args. + Shapes: + - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``. + - action_args (:obj:`torch.Tensor`): :math:`(B, M)`, where M is ``action_args_shape``. + Examples: + >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )}) + >>> model = PDQN(4, act_shape) + >>> inputs = torch.randn(64, 4) + >>> outputs = model.forward(inputs, mode='compute_continuous') + >>> assert outputs['action_args'].shape == torch.Size([64, 5]) """ cont_x = self.encoder[1](inputs) # size (B, encoded_state_shape) action_args = self.actor_head[1](cont_x)['pred'] # size (B, action_args_shape) @@ -154,15 +174,25 @@ def compute_continuous(self, inputs: torch.Tensor) -> Dict: return outputs def compute_discrete(self, inputs: Union[Dict, EasyDict]) -> Dict: - r""" + """ Overview: Use observation tensor and continuous action args to predict discrete action types. Arguments: - - inputs (:obj:`torch.Tensor`): A dict with keys 'state', 'action_args' + - inputs (:obj:`Union[Dict, EasyDict]`): A dict with keys 'state', 'action_args'. + - state (:obj:`torch.Tensor`): Observation inputs. + - action_args (:obj:`torch.Tensor`): Action parameters are used to concatenate with the observation \ + and serve as input to the discrete action type network. Returns: - - outputs (:obj:`Dict`): A dict with keys 'logit', 'action_args' - - 'logit': the logit value for each discrete action, - - 'action_args': the continuous action args(same as the inputs['action_args']) for later usage + - outputs (:obj:`Dict`): A dict with keys 'logit', 'action_args'. + - 'logit': The logit value for each discrete action. + - 'action_args': The continuous action args(same as the inputs['action_args']) for later usage. + Examples: + >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )}) + >>> model = PDQN(4, act_shape) + >>> inputs = {'state': torch.randn(64, 4), 'action_args': torch.randn(64, 5)} + >>> outputs = model.forward(inputs, mode='compute_discrete') + >>> assert outputs['logit'].shape == torch.Size([64, 3]) + >>> assert outputs['action_args'].shape == torch.Size([64, 5]) """ dis_x = self.encoder[0](inputs['state']) # size (B, encoded_state_shape) action_args = inputs['action_args'] # size (B, action_args_shape) diff --git a/ding/model/template/qmix.py b/ding/model/template/qmix.py index 483b48027c..68354e0cf7 100644 --- a/ding/model/template/qmix.py +++ b/ding/model/template/qmix.py @@ -11,20 +11,32 @@ class Mixer(nn.Module): """ Overview: - mixer network in QMIX, which mix up the independent q_value of each agent to a total q_value + Mixer network in QMIX, which mix up the independent q_value of each agent to a total q_value. 
\
+        The weights (but not the biases) of the Mixer network are restricted to be non-negative and \
+        produced by separate hypernetworks. Each hypernetwork takes the global state s as input and generates \
+        the weights of one layer of the Mixer network.
     Interface:
-        __init__, forward
+        ``__init__``, ``forward``.
     """

-    def __init__(self, agent_num, state_dim, mixing_embed_dim, hypernet_embed=64, activation=nn.ReLU()):
+    def __init__(
+            self,
+            agent_num: int,
+            state_dim: int,
+            mixing_embed_dim: int,
+            hypernet_embed: int = 64,
+            activation: nn.Module = nn.ReLU()
+    ):
         """
         Overview:
-            Initialize mixer network proposed in QMIX.
+            Initialize the mixer network proposed in QMIX according to arguments. Each hypernetwork consists of \
+            linear layers, followed by an absolute activation function, to ensure that the Mixer network weights are \
+            non-negative.
         Arguments:
-            - agent_num (:obj:`int`): the number of agent
-            - state_dim(:obj:`int`): the dimension of global observation state
-            - mixing_embed_dim (:obj:`int`): the dimension of mixing state emdedding
-            - hypernet_embed (:obj:`int`): the dimension of hypernet emdedding, default to 64
+            - agent_num (:obj:`int`): The number of agents, such as 8.
+            - state_dim (:obj:`int`): The dimension of global observation state, such as 16.
+            - mixing_embed_dim (:obj:`int`): The dimension of mixing state embedding, such as 128.
+            - hypernet_embed (:obj:`int`): The dimension of hypernet embedding, defaults to 64.
             - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
         """
         super(Mixer, self).__init__()
@@ -41,7 +53,7 @@ def __init__(self, agent_num, state_dim, mixing_embed_dim, hypernet_embed=64, ac
             nn.Linear(self.state_dim, hypernet_embed), self.act, nn.Linear(hypernet_embed, self.embed_dim)
         )

-        # State dependent bias for hidden layer
+        # state dependent bias for hidden layer
         self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

         # V(s) instead of a bias for the last layers
@@ -50,16 +62,17 @@ def __init__(self, agent_num, state_dim, mixing_embed_dim, hypernet_embed=64, ac
     def forward(self, agent_qs, states):
         """
         Overview:
-            forward computation graph of pymarl mixer network
+            Forward computation graph of the pymarl mixer network. Mix up the independent q_value of each agent \
+            into a total q_value, with weights generated by the hypernetworks according to the global ``states``.
         Arguments:
-            - agent_qs (:obj:`torch.FloatTensor`): the independent q_value of each agent
-            - states (:obj:`torch.FloatTensor`): the emdedding vector of global state
+            - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
+            - states (:obj:`torch.FloatTensor`): The embedding vector of global state.
         Returns:
-            - q_tot (:obj:`torch.FloatTensor`): the total mixed q_value
+            - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.
         Shapes:
-            - agent_qs (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is agent_num
-            - states (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is embedding_size
-            - q_tot (:obj:`torch.FloatTensor`): :math:`(B, )`
+            - agent_qs (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``agent_num``.
+            - states (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``state_dim``.
+            - q_tot (:obj:`torch.FloatTensor`): :math:`(B, )`.
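+        Examples:
+            >>> # A minimal usage sketch with illustrative sizes: 4 agents, a 32-dim global state, batch 10.
+            >>> mixer = Mixer(agent_num=4, state_dim=32, mixing_embed_dim=64)
+            >>> agent_qs = torch.randn(10, 4)
+            >>> states = torch.randn(10, 32)
+            >>> q_tot = mixer(agent_qs, states)
+            >>> assert q_tot.shape == torch.Size([10])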
""" bs = agent_qs.shape[:-1] states = states.reshape(-1, self.state_dim) @@ -86,9 +99,12 @@ def forward(self, agent_qs, states): class QMix(nn.Module): """ Overview: - QMIX network + The neural network and computation graph of algorithms related to QMIX(https://arxiv.org/abs/1803.11485). \ + The QMIX is composed of two parts: agent Q network and mixer(optional). The QMIX paper mentions that all \ + agents share local Q network parameters, so only one Q network is initialized here. Then use summation or \ + Mixer network to process the local Q according to the ``mixer`` settings to obtain the global Q. Interface: - __init__, forward, _setup_global_encoder + ``__init__``, ``forward``. """ def __init__( @@ -105,17 +121,22 @@ def __init__( ) -> None: """ Overview: - Initialize QMIX neural network, i.e. agent Q network and mixer. + Initialize QMIX neural network according to arguments, i.e. agent Q network and mixer. Arguments: - - agent_num (:obj:`int`): the number of agent - - obs_shape (:obj:`int`): the dimension of each agent's observation state - - global_obs_shape (:obj:`int`): the dimension of global observation state - - action_shape (:obj:`int`): the dimension of action shape - - hidden_size_list (:obj:`list`): the list of hidden size - - mixer (:obj:`bool`): use mixer net or not, default to True - - lstm_type (:obj:`str`): use lstm or gru, default to gru - - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU(). - - dueling (:obj:`bool`): use dueling head or not, default to False. + - agent_num (:obj:`int`): The number of agent, such as 8. + - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8 or [4, 84, 84]. + - global_obs_shape (:obj:`int`): The dimension of global observation state, such as 8 or [4, 84, 84]. + - action_shape (:obj:`int`): The dimension of action shape, such as 6 or [2, 3, 3]. + - hidden_size_list (:obj:`list`): The list of hidden size for ``q_network``, \ + the last element must match mixer's ``mixing_embed_dim``. + - mixer (:obj:`bool`): Use mixer net or not, default to True. If it is false, \ + the final local Q is added to obtain the global Q. + - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now support \ + ['normal', 'pytorch', 'gru'], default to gru. + - activation (:obj:`nn.Module`): The type of activation function to use in ``MLP`` the after \ + ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``. + - dueling (:obj:`bool`): Whether choose ``DuelingHead`` (True) or ``DiscreteHead (False)``, \ + default to False. """ super(QMix, self).__init__() self._act = activation @@ -131,30 +152,32 @@ def __init__( def forward(self, data: dict, single_step: bool = True) -> dict: """ Overview: - forward computation graph of qmix network + QMIX forward computation graph, input dict including time series observation and related data to predict \ + total q_value and each agent q_value. Arguments: - - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action'] - - agent_state (:obj:`torch.Tensor`): each agent local state(obs) - - global_state (:obj:`torch.Tensor`): global state(obs) - - prev_state (:obj:`list`): previous rnn state - - action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to\ - calculate ``agent_q_act`` - - single_step (:obj:`bool`): whether single_step forward, if so, add timestep dim before forward and\ - remove it after forward + - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action']. 
+ - agent_state (:obj:`torch.Tensor`): Time series local observation data of each agents. + - global_state (:obj:`torch.Tensor`): Time series global observation data. + - prev_state (:obj:`list`): Previous rnn state for ``q_network``. + - action (:obj:`torch.Tensor` or None): The actions of each agent given outside the function. \ + If action is None, use argmax q_value index as action to calculate ``agent_q_act``. + - single_step (:obj:`bool`): Whether single_step forward, if so, add timestep dim before forward and\ + remove it after forward. Returns: - - ret (:obj:`dict`): output data dict with keys [``total_q``, ``logit``, ``next_state``] - - total_q (:obj:`torch.Tensor`): total q_value, which is the result of mixer network - - agent_q (:obj:`torch.Tensor`): each agent q_value - - next_state (:obj:`list`): next rnn state + - ret (:obj:`dict`): Output data dict with keys [``total_q``, ``logit``, ``next_state``]. + ReturnsKeys: + - total_q (:obj:`torch.Tensor`): Total q_value, which is the result of mixer network. + - agent_q (:obj:`torch.Tensor`): Each agent q_value. + - next_state (:obj:`list`): Next rnn state for ``q_network``. Shapes: - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size\ - A is agent_num, N is obs_shape - - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape - - prev_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A - - action (:obj:`torch.Tensor`): :math:`(T, B, A)` - - total_q (:obj:`torch.Tensor`): :math:`(T, B)` - - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape - - next_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A + A is agent_num, N is obs_shape. + - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape. + - prev_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A. + - action (:obj:`torch.Tensor`): :math:`(T, B, A)`. + - total_q (:obj:`torch.Tensor`): :math:`(T, B)`. + - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape. + - next_state (:obj:`list`): math:`(B, A)`, a list of length B, and each element is a list of length A. """ agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[ 'prev_state'] @@ -172,7 +195,7 @@ def forward(self, data: dict, single_step: bool = True) -> dict: next_state, _ = list_split(next_state, step=A) agent_q = agent_q.reshape(T, B, A, -1) if action is None: - # For target forward process + # for target forward process if len(data['obs']['action_mask'].shape) == 3: action_mask = data['obs']['action_mask'].unsqueeze(0) else: @@ -194,15 +217,3 @@ def forward(self, data: dict, single_step: bool = True) -> dict: 'next_state': next_state, 'action_mask': data['obs']['action_mask'] } - - def _setup_global_encoder(self, global_obs_shape: int, embedding_size: int) -> torch.nn.Module: - """ - Overview: - Used to encoder global observation. - Arguments: - - global_obs_shape (:obj:`int`): the dimension of global observation state - - embedding_size (:obj:`int`): the dimension of state emdedding - Return: - - outputs (:obj:`torch.nn.Module`): Global observation encoding network - """ - return MLP(global_obs_shape, embedding_size, embedding_size, 2, activation=self._act)
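For reference, a minimal usage sketch of ``QMix.forward`` with assumed sizes (T=8 timesteps, B=4 batch, A=3 agents, obs dim 32, global obs dim 64, 6 actions). The ``None``-filled ``prev_state`` initialization and the chosen ``hidden_size_list`` are illustrative assumptions, following the ``(B, A)`` nested-list convention documented in the Shapes section above:

    >>> # illustrative sketch, not part of the patched docstrings
    >>> model = QMix(agent_num=3, obs_shape=32, global_obs_shape=64, action_shape=6, hidden_size_list=[64, 64])
    >>> data = {
            'obs': {
                'agent_state': torch.randn(8, 4, 3, 32),
                'global_state': torch.randn(8, 4, 64),
                'action_mask': torch.ones(8, 4, 3, 6)
            },
            'prev_state': [[None for _ in range(3)] for _ in range(4)],
            'action': torch.randint(0, 6, size=(8, 4, 3))
        }
    >>> output = model(data, single_step=False)
    >>> assert output['total_q'].shape == torch.Size([8, 4])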